From edd53d817d81e39d4805e8233ce225aca3fc2912 Mon Sep 17 00:00:00 2001 From: ElektrikSpark Date: Sat, 19 Oct 2024 03:59:21 +0000 Subject: [PATCH 1/5] feat: added protein splitting, call to netMHCpan --- apps/fastapi/pyproject.toml | 1 + apps/fastapi/requirements.txt | 22 +- .../app/api/api_v1/endpoints/prediction.py | 3 + apps/fastapi/src/app/core/utils.py | 69 + .../crud/crud_conformational_b_prediction.py | 14 +- .../src/app/crud/crud_linear_b_prediction.py | 14 +- .../src/app/crud/crud_mhc_i_prediction.py | 14 +- .../src/app/crud/crud_mhc_ii_prediction.py | 14 +- .../schemas/conformational_b_prediction.py | 10 +- .../src/app/schemas/linear_b_prediction.py | 4 +- .../src/app/schemas/mhc_i_prediction.py | 10 +- .../src/app/schemas/mhc_ii_prediction.py | 10 +- .../src/app/services/background_tasks.py | 50 +- apps/fastapi/src/app/services/inference.py | 107 +- apps/fastapi/src/app/services/pipeline.py | 4 +- apps/fastapi/src/app/services/postprocess.py | 106 +- apps/fastapi/src/app/services/preprocess.py | 11 + apps/fastapi/uv.lock | 63 + notebooks/SaprotHub.ipynb | 3189 ----------------- notebooks/protein_language_modeling.ipynb | 2554 ------------- 20 files changed, 467 insertions(+), 5802 deletions(-) delete mode 100644 notebooks/SaprotHub.ipynb delete mode 100644 notebooks/protein_language_modeling.ipynb diff --git a/apps/fastapi/pyproject.toml b/apps/fastapi/pyproject.toml index df546a2..5f077bb 100644 --- a/apps/fastapi/pyproject.toml +++ b/apps/fastapi/pyproject.toml @@ -9,6 +9,7 @@ authors = [ readme = "README.md" requires-python = ">=3.12" dependencies = [ + "aioboto3>=13.2.0", "aiofiles>=24.1.0", "boto3>=1.35.29", "fastapi[standard]>=0.115.0", diff --git a/apps/fastapi/requirements.txt b/apps/fastapi/requirements.txt index 5f346a3..ee92731 100644 --- a/apps/fastapi/requirements.txt +++ b/apps/fastapi/requirements.txt @@ -1,11 +1,21 @@ # This file was autogenerated by uv via the following command: # uv pip compile pyproject.toml -o requirements.txt -aiofiles==24.1.0 +aioboto3==13.2.0 # via fastapi-epitope-prediction (pyproject.toml) +aiobotocore==2.15.2 + # via aioboto3 +aiofiles==24.1.0 + # via + # fastapi-epitope-prediction (pyproject.toml) + # aioboto3 aiohappyeyeballs==2.4.3 # via aiohttp aiohttp==3.10.8 - # via realtime + # via + # aiobotocore + # realtime +aioitertools==0.12.0 + # via aiobotocore aiosignal==1.3.1 # via aiohttp annotated-types==0.7.0 @@ -19,9 +29,12 @@ anyio==4.6.0 attrs==24.2.0 # via aiohttp boto3==1.35.32 - # via fastapi-epitope-prediction (pyproject.toml) + # via + # fastapi-epitope-prediction (pyproject.toml) + # aiobotocore botocore==1.35.32 # via + # aiobotocore # boto3 # s3transfer certifi==2024.8.30 @@ -77,6 +90,7 @@ httptools==0.6.1 # via uvicorn httpx==0.27.2 # via + # fastapi-epitope-prediction (pyproject.toml) # fastapi # gotrue # openai @@ -235,5 +249,7 @@ websockets==13.1 # via # realtime # uvicorn +wrapt==1.16.0 + # via aiobotocore yarl==1.13.1 # via aiohttp diff --git a/apps/fastapi/src/app/api/api_v1/endpoints/prediction.py b/apps/fastapi/src/app/api/api_v1/endpoints/prediction.py index 9f193c0..e97263d 100644 --- a/apps/fastapi/src/app/api/api_v1/endpoints/prediction.py +++ b/apps/fastapi/src/app/api/api_v1/endpoints/prediction.py @@ -97,7 +97,9 @@ async def create_mhc_i_prediction( process_and_update_prediction, job_id=job.id, sequence=prediction_in.sequence, + alleles=prediction_in.alleles, prediction_type="mhc-i", + user_id=user.id, db=db, ) @@ -124,6 +126,7 @@ async def create_mhc_ii_prediction( process_and_update_prediction, 
job_id=job.id, sequence=prediction_in.sequence, + alleles=prediction_in.alleles, prediction_type="mhc-ii", + user_id=user.id, db=db, ) diff --git a/apps/fastapi/src/app/core/utils.py b/apps/fastapi/src/app/core/utils.py index 302400c..988f768 100644 --- a/apps/fastapi/src/app/core/utils.py +++ b/apps/fastapi/src/app/core/utils.py @@ -1,9 +1,11 @@ import csv +import io import logging import re from io import StringIO from typing import List, Optional, Type, TypeVar +import aioboto3 import boto3 import httpx from fastapi import HTTPException @@ -169,3 +171,70 @@ async def fetch_pdb_data(pdb_id: str, chain: Optional[str] = None) -> dict: status_code=response.status_code, detail=f"Error fetching PDB data: {response.status_code}", ) + + +def split_protein_sequence( + protein_sequence: str, min_length: int, max_length: int +) -> List[str]: + """ + Splits a protein sequence into peptides based on the provided min and max lengths. + """ + peptides = [] + for length in range(min_length, max_length + 1): + peptides.extend( + [ + protein_sequence[i : i + length] + for i in range(len(protein_sequence) - length + 1) + ] + ) + return peptides + + +def get_default_peptide_lengths(prediction_type: str): + if prediction_type == "mhc-i": + return 8, 11 + elif prediction_type == "mhc-ii": + return 13, 25 + # Add other prediction types if needed + return 8, 11 # Fallback default lengths + + +def generate_csv_key( + user_id: str, job_id: str, timestamp: str, prediction_type: str ) -> str: + """ + Generates a unique S3 key for the CSV file based on user ID, job ID, and timestamp. + """ + return f"predictions/{user_id}/{job_id}_{prediction_type}_{timestamp}.csv" + + +async def upload_csv_to_s3(results: List[T], s3_key: str): + """ + Uploads the processed results to S3 as a CSV file, using the Pydantic schema to generate the columns. + :param results: List of Pydantic models (MhcIPredictionResult, MhcIIPredictionResult, etc.) + :param s3_key: The key (path) where the CSV will be stored in S3.
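+ Example (illustrative only; a minimal sketch of a call. The row and key below are placeholders, not values produced elsewhere in this patch):
+ >>> rows = [MhcIPredictionResult(Peptide_Sequence="SIINFEKL")]
+ >>> await upload_csv_to_s3(rows, "predictions/<user_id>/<job_id>_mhc-i_<timestamp>.csv")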
+ """ + if not results: + raise HTTPException(status_code=400, detail="No results to upload") + + # Dynamically get the field names (columns) from the Pydantic schema + fieldnames = list(results[0].model_fields.keys()) + + output = io.StringIO() + writer = csv.DictWriter(output, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC) + writer.writeheader() + + # Convert results (Pydantic models) to dictionaries for CSV writing + for result in results: + writer.writerow(result.model_dump()) + + csv_content = output.getvalue() + + # Upload the CSV to S3 + async with aioboto3.client("s3", region_name=settings.AWS_REGION) as s3_client: + try: + await s3_client.put_object( + Bucket=settings.S3_BUCKET_NAME, Key=s3_key, Body=csv_content + ) + except Exception: + raise HTTPException(status_code=500, detail="Failed to upload CSV to S3") diff --git a/apps/fastapi/src/app/crud/crud_conformational_b_prediction.py b/apps/fastapi/src/app/crud/crud_conformational_b_prediction.py index de1fcc7..3af0cfd 100644 --- a/apps/fastapi/src/app/crud/crud_conformational_b_prediction.py +++ b/apps/fastapi/src/app/crud/crud_conformational_b_prediction.py @@ -50,7 +50,12 @@ async def create( return self.model(**created_prediction) async def update_result( - self, db: AsyncClient, *, job_id: str, result: List[PredictionResult] + self, + db: AsyncClient, + *, + job_id: str, + result: List[PredictionResult], + csv_download_url: str, ) -> ConformationalBPrediction: prediction = await self.get_by_job_id(db=db, job_id=job_id) if not prediction: @@ -59,7 +64,12 @@ async def update_result( # Update the result field updated_prediction = ( await db.table(self.model.table_name) - .update({"result": [res.model_dump() for res in result]}) + .update( + { + "result": [res.model_dump() for res in result], + "csv_download_url": csv_download_url, + } + ) .eq("job_id", job_id) .execute() ) diff --git a/apps/fastapi/src/app/crud/crud_linear_b_prediction.py b/apps/fastapi/src/app/crud/crud_linear_b_prediction.py index 08369ed..a94bc9c 100644 --- a/apps/fastapi/src/app/crud/crud_linear_b_prediction.py +++ b/apps/fastapi/src/app/crud/crud_linear_b_prediction.py @@ -50,7 +50,12 @@ async def create( return self.model(**created_prediction) async def update_result( - self, db: AsyncClient, *, job_id: str, result: List[LBPredictionResult] + self, + db: AsyncClient, + *, + job_id: str, + result: List[LBPredictionResult], + csv_download_url: str, ) -> LinearBPrediction: prediction = await self.get_by_job_id(db=db, job_id=job_id) if not prediction: @@ -59,7 +64,12 @@ async def update_result( # Update the result field updated_prediction = ( await db.table(self.model.table_name) - .update({"result": [res.model_dump() for res in result]}) + .update( + { + "result": [res.model_dump() for res in result], + "csv_download_url": csv_download_url, + } + ) .eq("job_id", job_id) .execute() ) diff --git a/apps/fastapi/src/app/crud/crud_mhc_i_prediction.py b/apps/fastapi/src/app/crud/crud_mhc_i_prediction.py index c02659d..ce080ce 100644 --- a/apps/fastapi/src/app/crud/crud_mhc_i_prediction.py +++ b/apps/fastapi/src/app/crud/crud_mhc_i_prediction.py @@ -50,7 +50,12 @@ async def create( return self.model(**created_prediction) async def update_result( - self, db: AsyncClient, *, job_id: str, result: List[MhcIPredictionResult] + self, + db: AsyncClient, + *, + job_id: str, + result: List[MhcIPredictionResult], + csv_download_url: str, ) -> MhcIPrediction: prediction = await self.get_by_job_id(db=db, job_id=job_id) if not prediction: @@ -59,7 +64,12 @@ async def 
update_result( # Update the result field updated_prediction = ( await db.table(self.model.table_name) - .update({"result": [res.model_dump() for res in result]}) + .update( + { + "result": [res.model_dump() for res in result], + "csv_download_url": csv_download_url, + } + ) .eq("job_id", job_id) .execute() ) diff --git a/apps/fastapi/src/app/crud/crud_mhc_ii_prediction.py b/apps/fastapi/src/app/crud/crud_mhc_ii_prediction.py index 65fc366..c9a95ec 100644 --- a/apps/fastapi/src/app/crud/crud_mhc_ii_prediction.py +++ b/apps/fastapi/src/app/crud/crud_mhc_ii_prediction.py @@ -50,7 +50,12 @@ async def create( return self.model(**created_prediction) async def update_result( - self, db: AsyncClient, *, job_id: str, result: List[MhcIIPredictionResult] + self, + db: AsyncClient, + *, + job_id: str, + result: List[MhcIIPredictionResult], + csv_download_url: str, ) -> MhcIIPrediction: prediction = await self.get_by_job_id(db=db, job_id=job_id) if not prediction: @@ -59,7 +64,12 @@ async def update_result( # Update the result field updated_prediction = ( await db.table(self.model.table_name) - .update({"result": [res.model_dump() for res in result]}) + .update( + { + "result": [res.model_dump() for res in result], + "csv_download_url": csv_download_url, + } + ) .eq("job_id", job_id) .execute() ) diff --git a/apps/fastapi/src/app/schemas/conformational_b_prediction.py b/apps/fastapi/src/app/schemas/conformational_b_prediction.py index 01e52a0..b459b79 100644 --- a/apps/fastapi/src/app/schemas/conformational_b_prediction.py +++ b/apps/fastapi/src/app/schemas/conformational_b_prediction.py @@ -16,11 +16,11 @@ class PredictionResult(BaseModel): AA: str Epitope_score: float N_glyco_label: int - Hydrophilicity: Optional[float] = None # fix - Charge: Optional[int] = None # fix - ASA: Optional[float] = None - RSA: Optional[float] = None - B_Factor: Optional[float] = None + Hydrophilicity: Optional[float] = Field(default=None) + Charge: Optional[int] = Field(default=None) + ASA: Optional[float] = Field(default=None) + RSA: Optional[float] = Field(default=None) + B_Factor: Optional[float] = Field(default=None) class ConformationalBPredictionCreate(CreateBase, JobMixin): diff --git a/apps/fastapi/src/app/schemas/linear_b_prediction.py b/apps/fastapi/src/app/schemas/linear_b_prediction.py index ca5e134..8475463 100644 --- a/apps/fastapi/src/app/schemas/linear_b_prediction.py +++ b/apps/fastapi/src/app/schemas/linear_b_prediction.py @@ -8,8 +8,8 @@ # Define the structure of the CSV data for Linear B as a Pydantic model class LBPredictionResult(BaseModel): Peptide_Sequence: str - Linear_B_Cell_Immunogenicity: Optional[float] - Linear_BCR_Recognition: float + Linear_B_Cell_Immunogenicity: Optional[float] = Field(default=None) + Linear_BCR_Recognition: Optional[float] = Field(default=None) class LinearBPredictionCreate(CreateBase, JobMixin): diff --git a/apps/fastapi/src/app/schemas/mhc_i_prediction.py b/apps/fastapi/src/app/schemas/mhc_i_prediction.py index 7fbdd7e..44a8418 100644 --- a/apps/fastapi/src/app/schemas/mhc_i_prediction.py +++ b/apps/fastapi/src/app/schemas/mhc_i_prediction.py @@ -8,11 +8,11 @@ # Define the structure of the CSV data for MHC-I as a Pydantic model class MhcIPredictionResult(BaseModel): Peptide_Sequence: str - ClassI_TCR_Recognition: float - ClassI_MHC_Binding_Affinity: str - ClassI_pMHC_Stability: str - Best_Binding_Affinity: str - Best_pMHC_Stability: str + ClassI_TCR_Recognition: Optional[float] = Field(default=None) + ClassI_MHC_Binding_Affinity: Optional[str] = Field(default="") + 
ClassI_pMHC_Stability: Optional[str] = Field(default="") + Best_Binding_Affinity: Optional[str] = Field(default="") + Best_pMHC_Stability: Optional[str] = Field(default="") class MhcIPredictionCreate(CreateBase, JobMixin): diff --git a/apps/fastapi/src/app/schemas/mhc_ii_prediction.py b/apps/fastapi/src/app/schemas/mhc_ii_prediction.py index 5110047..8ad0155 100644 --- a/apps/fastapi/src/app/schemas/mhc_ii_prediction.py +++ b/apps/fastapi/src/app/schemas/mhc_ii_prediction.py @@ -8,11 +8,11 @@ # Define the structure of the CSV data for MHC-II as a Pydantic model class MhcIIPredictionResult(BaseModel): Peptide_Sequence: str - ClassII_TCR_Recognition: float - ClassII_MHC_Binding_Affinity: str - ClassII_pMHC_Stability: str - Best_Binding_Affinity: str - Best_pMHC_Stability: str + ClassII_TCR_Recognition: Optional[float] = Field(default=None) + ClassII_MHC_Binding_Affinity: Optional[str] = Field(default="") + ClassII_pMHC_Stability: Optional[str] = Field(default="") + Best_Binding_Affinity: Optional[str] = Field(default="") + Best_pMHC_Stability: Optional[str] = Field(default="") class MhcIIPredictionCreate(CreateBase, JobMixin): diff --git a/apps/fastapi/src/app/services/background_tasks.py b/apps/fastapi/src/app/services/background_tasks.py index 5c7b5c5..826d5f6 100644 --- a/apps/fastapi/src/app/services/background_tasks.py +++ b/apps/fastapi/src/app/services/background_tasks.py @@ -1,34 +1,36 @@ import logging -from typing import Optional +from typing import List, Optional from fastapi import HTTPException from supabase._async.client import AsyncClient -from app.core.config import settings from app.crud.crud_conformational_b_prediction import crud_conformational_b_prediction from app.crud.crud_job import crud_job from app.crud.crud_linear_b_prediction import crud_linear_b_prediction -from app.crud.crud_mhc_i_prediction import crud_mhc_i_prediction from app.crud.crud_mhc_ii_prediction import crud_mhc_ii_prediction -from app.services.pipeline import run +from app.services.inference import run_netmhci_binding_affinity_classI from app.services.postprocess import ( + postprocess_mhc_i_prediction, + process_classI_results, process_conformational_b_prediction, process_linear_b_prediction, - process_mhc_i_prediction, process_mhc_ii_prediction, ) +from app.services.preprocess import preprocess_protein_sequence logger = logging.getLogger(__name__) async def process_and_update_prediction( + db: AsyncClient, + user_id: str, job_id: str, prediction_type: str, - db: AsyncClient, sequence: Optional[str] = None, pdb_id: Optional[str] = None, # Only needed for Conformational B chain: Optional[str] = None, # Only needed for Conformational B is_structure_based: Optional[bool] = False, # Only needed for Conformational B + alleles: Optional[List[str]] = None, ): """ Background task to process a prediction based on its type and update the database.
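The MHC-I branch added below relies on the peptide windowing introduced in core/utils.py above. A minimal sketch of what that windowing yields, assuming a hypothetical toy fragment (not data from this patch):

from app.core.utils import get_default_peptide_lengths, split_protein_sequence

# Placeholder fragment, for illustration only.
toy_sequence = "SIINFEKLLTEWTSSNVMEERKIKV"
min_len, max_len = get_default_peptide_lengths("mhc-i")  # returns (8, 11)
peptides = split_protein_sequence(toy_sequence, min_len, max_len)
# One window per start position for each length 8..11, so
# len(peptides) == sum(len(toy_sequence) - L + 1 for L in range(8, 12))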
@@ -42,13 +44,13 @@ async def process_and_update_prediction( sample = {"inputs": glp_1_receptor} # Run prediction pipeline - predictions = run( - input_objects=[sample], - endpoint_name=settings.SAGEMAKER_ENDPOINT_NAME, - model_type="single", - ) + # predictions = run( + # input_objects=[sample], + # endpoint_name=settings.SAGEMAKER_ENDPOINT_NAME, + # model_type="single", + # ) - logger.info(predictions) + # logger.info(predictions) # Process the prediction based on its type if prediction_type == "conformational-b": @@ -68,11 +70,29 @@ async def process_and_update_prediction( db=db, job_id=job_id, result=results ) elif prediction_type == "mhc-i": - results = await process_mhc_i_prediction(sequence=sequence) - await crud_mhc_i_prediction.update_result( - db=db, job_id=job_id, result=results + # Step 1: Split protein sequence into peptides (preprocessing) + peptides = preprocess_protein_sequence(sequence, prediction_type) + + # Step 2: Run NetMHCpan-4.1 binding affinity predictions (inference) + netmhci_results = await run_netmhci_binding_affinity_classI( + peptides, alleles + ) + + # Step 3: Process NetMHCpan-4.1 results (postprocessing) + processed_results = await process_classI_results(netmhci_results) + + # Step 4: Process the results (postprocessing) + await postprocess_mhc_i_prediction( + db=db, + job_id=job_id, + results=processed_results, # Passing results from inference + user_id=user_id, + prediction_type=prediction_type, ) elif prediction_type == "mhc-ii": + # Step 1: Split protein sequence into peptides + peptides = preprocess_protein_sequence(sequence, prediction_type) + results = await process_mhc_ii_prediction(sequence=sequence) await crud_mhc_ii_prediction.update_result( db=db, job_id=job_id, result=results diff --git a/apps/fastapi/src/app/services/inference.py b/apps/fastapi/src/app/services/inference.py index 59f1e52..78579cd 100644 --- a/apps/fastapi/src/app/services/inference.py +++ b/apps/fastapi/src/app/services/inference.py @@ -1,11 +1,21 @@ +import asyncio import json +import logging +from typing import Any, Dict, List import boto3 +import httpx from app.core.config import settings +logger = logging.getLogger(__name__) -def get_predictions(requests, endpoint_name, model_name, model_type="mme"): + +IEDB_API_URL_CLASSI = "http://tools-cluster-interface.iedb.org/tools_api/mhci/" +IEDB_API_URL_CLASSII = "http://tools-cluster-interface.iedb.org/tools_api/mhcii/" + + +def get_sagemaker_predictions(requests, endpoint_name, model_name, model_type="mme"): """ Pass preprocessed requests to the specific endpoint (mme or single) """ @@ -29,3 +39,98 @@ def get_predictions(requests, endpoint_name, model_name, model_type="mme"): res = json.loads(response["Body"].read().decode("utf-8")) responses.append(res) return responses + + +async def run_netmhci_binding_affinity_classI( + peptides: List[str], alleles: List[str], method: str = "netmhcpan-4.1" +) -> List[Dict[str, Any]]: + """ + Uses IEDB API to generate binding affinity for each peptide and HLA interaction. + + Args: + peptides (list): A list of peptide sequences. + alleles (list): A list of HLA alleles for which to make predictions. + method (str): Prediction method to use. + + Returns: + list: A list of dictionaries containing the binding affinity results or errors. 
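+ Example (a hypothetical call; the peptides and allele are placeholders, not values used elsewhere in this patch):
+ >>> results = await run_netmhci_binding_affinity_classI(
+ ... ["SIINFEKL", "GILGFVFTL"], ["HLA-A*02:01"]
+ ... )
+ Each returned entry carries "allele", "length" and "peptides", plus either "result" (raw IEDB response text) or "error".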
+ """ + results = [] + peptides_by_length = {} + + # Group peptides by their length + for peptide in peptides: + length = len(peptide) + peptides_by_length.setdefault(length, []).append(peptide) + + async with httpx.AsyncClient() as client: + for allele in alleles: + for length, peptides_subset in peptides_by_length.items(): + sequence_text = "\n".join( + [ + f">peptide{i}\n{peptide}" + for i, peptide in enumerate(peptides_subset) + ] + ) + + payload = { + "method": method, + "sequence_text": sequence_text, + "allele": allele, + "length": str(length), + "species": "human", + } + + retries = 0 + max_retries = 5 + while retries < max_retries: + try: + response = await client.post(IEDB_API_URL_CLASSI, data=payload) + + # Handle 403 and 500 errors with retry logic + if response.status_code in [403, 500]: + retries += 1 + sleep_time = 2**retries # Exponential backoff + logger.error( + f"Server error {response.status_code}. Retrying in {sleep_time} seconds..." + ) + await asyncio.sleep(sleep_time) + else: + response.raise_for_status() # Raise error for any other issues + results.append( + { + "allele": allele, + "length": length, + "peptides": peptides_subset, + "result": response.text, + } + ) + logger.info( + f"Successfully retrieved data for allele {allele} and length {length}." + ) + break # Break loop on success + except httpx.RequestError as e: + if retries == max_retries: + logger.error( + f"Max retries reached for allele {allele} and length {length}: {e}" + ) + results.append( + { + "allele": allele, + "length": length, + "peptides": peptides_subset, + "error": str(e), + } + ) + else: + retries += 1 + sleep_time = 2**retries + logger.error( + f"Request error. Retrying in {sleep_time} seconds for allele {allele} and length {length}: {e}" + ) + await asyncio.sleep(sleep_time) + except Exception as e: + logger.error(f"Unexpected error: {e}") + break + + return results diff --git a/apps/fastapi/src/app/services/pipeline.py b/apps/fastapi/src/app/services/pipeline.py index 034c9b9..b64d5ae 100644 --- a/apps/fastapi/src/app/services/pipeline.py +++ b/apps/fastapi/src/app/services/pipeline.py @@ -1,7 +1,7 @@ from typing import Any, List from app.core.utils import get_endpoint -from app.services.inference import get_predictions +from app.services.inference import get_sagemaker_predictions def run( @@ -31,7 +31,7 @@ def run( preprocess = [obj for obj in input_objects] # Run inference - model_responses = get_predictions( + model_responses = get_sagemaker_predictions( preprocess, endpoint_name, model_name, diff --git a/apps/fastapi/src/app/services/postprocess.py b/apps/fastapi/src/app/services/postprocess.py index f33eeb0..07f33e2 100644 --- a/apps/fastapi/src/app/services/postprocess.py +++ b/apps/fastapi/src/app/services/postprocess.py @@ -1,10 +1,15 @@ import logging -from typing import List, Optional +from datetime import datetime +from io import StringIO +from typing import Any, Dict, List, Optional +import pandas as pd from fastapi import HTTPException +from supabase import AsyncClient from app.core.config import settings -from app.core.utils import read_s3_csv +from app.core.utils import generate_csv_key, read_s3_csv, upload_csv_to_s3 +from app.crud import crud_mhc_i_prediction from app.schemas.conformational_b_prediction import PredictionResult from app.schemas.linear_b_prediction import LBPredictionResult from app.schemas.mhc_i_prediction import MhcIPredictionResult @@ -71,22 +76,97 @@ async def process_linear_b_prediction(sequence: str) -> List[LBPredictionResult] return results -async 
def process_mhc_i_prediction(sequence: str) -> List[MhcIPredictionResult]: +# MHC-I Predictions +async def process_classI_results( + results: List[Dict[str, Any]], +) -> List[MhcIPredictionResult]: """ - Process an MHC-I prediction by reading a CSV file from S3 and validating results. + Processes the Class I results returned by the IEDB API. """ - csv_filename = "class_I.csv" - s3_key = f"data/{csv_filename}" # The S3 key for the file - - # Use the utility function to read the CSV and validate rows - results = read_s3_csv(settings.S3_BUCKET_NAME, s3_key, MhcIPredictionResult) + peptide_data = {} + + for res in results: + if "result" in res: + try: + df = pd.read_csv(StringIO(res["result"]), sep="\t") + if {"peptide", "allele", "ic50"}.issubset(df.columns): + for _, row in df.iterrows(): + peptide = row["peptide"] + allele = row["allele"] + ic50 = row["ic50"] + peptide_data.setdefault(peptide, {"binding_affinities": []}) + peptide_data[peptide]["binding_affinities"].append( + (allele, float(ic50)) + ) + else: + logger.warning(f"Unexpected columns in API response: {df.columns}") + except pd.errors.EmptyDataError: + logger.warning( + f"Received empty data from API for allele {res['allele']} and length {res['length']}." + ) + else: + # Handle errors if any + for peptide in res["peptides"]: + peptide_data.setdefault(peptide, {"binding_affinities": []}) + logger.error(f"Error for peptide {peptide}: {res.get('error')}") + + # Format processed results + processed_results = [] + for peptide, data in peptide_data.items(): + binding_affinities = data.get("binding_affinities", []) + binding_affinity_str = "|".join( + [f"{allele}={ic50} nM" for allele, ic50 in binding_affinities] + ) + best_binding_affinity = ( + f"{min(binding_affinities, key=lambda x: x[1])}" + if binding_affinities + else "" + ) - if not results: - raise HTTPException( - status_code=404, detail=f"CSV file not found in S3 for sequence {sequence}." + processed_results.append( + MhcIPredictionResult( + Peptide_Sequence=peptide, + ClassI_MHC_Binding_Affinity=binding_affinity_str, + Best_Binding_Affinity=best_binding_affinity, + ) ) - return results + return processed_results + + +async def postprocess_mhc_i_prediction( + db: AsyncClient, + job_id: str, + results: List[MhcIPredictionResult], # Results from inference + user_id: str, + prediction_type: str, +): + """ + Consolidated postprocessing function for MHC-I predictions. 
+ - Processes the prediction results + - Uploads results as a CSV to S3 + - Updates the database with the results and CSV download URL + """ + # Step 1: Process the results + processed_results = results + + # Step 2: Generate a unique CSV filename and upload to S3 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + s3_key = generate_csv_key( + user_id=user_id, + job_id=job_id, + timestamp=timestamp, + prediction_type=prediction_type, + ) + await upload_csv_to_s3(processed_results, s3_key) + + # Step 3: Update the database with results and CSV URL + await crud_mhc_i_prediction.update_result( + db=db, + job_id=job_id, + result=processed_results, + csv_download_url=f"https://{settings.S3_BUCKET_NAME}.s3.amazonaws.com/{s3_key}", + ) async def process_mhc_ii_prediction(sequence: str) -> List[MhcIIPredictionResult]: diff --git a/apps/fastapi/src/app/services/preprocess.py b/apps/fastapi/src/app/services/preprocess.py index 904a531..ba5aea3 100644 --- a/apps/fastapi/src/app/services/preprocess.py +++ b/apps/fastapi/src/app/services/preprocess.py @@ -3,6 +3,8 @@ from fastapi import HTTPException +from app.core.utils import get_default_peptide_lengths, split_protein_sequence + logger = logging.getLogger(__name__) """ @@ -114,3 +116,12 @@ def validate_pdb_data(pdb_data: dict, chain: Optional[str] = None) -> Dict[str, sequence = extract_sequence(pdb_data) structure = extract_structure(pdb_data) if chain else None return {"sequence": sequence, "structure": structure} + + +def preprocess_protein_sequence(protein_sequence: str, prediction_type: str): + """ + Preprocess the protein sequence based on prediction type by splitting into peptides. + """ + min_length, max_length = get_default_peptide_lengths(prediction_type) + peptides = split_protein_sequence(protein_sequence, min_length, max_length) + return peptides diff --git a/apps/fastapi/uv.lock b/apps/fastapi/uv.lock index ed86e33..6d808b8 100644 --- a/apps/fastapi/uv.lock +++ b/apps/fastapi/uv.lock @@ -5,6 +5,39 @@ resolution-markers = [ "python_full_version >= '3.13'", ] +[[package]] +name = "aioboto3" +version = "13.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiobotocore", extra = ["boto3"] }, + { name = "aiofiles" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/35/1d/a36f39e95d15202236a5fec436377a9db712c5fe5a240325a5e54bc5e3ef/aioboto3-13.2.0.tar.gz", hash = "sha256:92c3232e0bf7dcb5d921cd1eb8c5e0b856c3985f7c1cd32ab3cd51adc5c9b5da", size = 32497 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/25/66/e4b2d8f3d11687f7c63b1b63e484ee879f9af637b3564026037655d83255/aioboto3-13.2.0-py3-none-any.whl", hash = "sha256:fd894b8d319934dfd75285b58da35560670e57182d0148c54a3d4ee5da730c78", size = 34738 }, +] + +[[package]] +name = "aiobotocore" +version = "2.15.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "aioitertools" }, + { name = "botocore" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e3/3d/5d54985abed848a4d4dafd10d7eb9ecd6bd7fff9533223911a92c2e6e15d/aiobotocore-2.15.2.tar.gz", hash = "sha256:9ac1cfcaccccc80602968174aa032bf978abe36bd4e55e6781d6500909af1375", size = 107035 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/57/6402242dde160d9ef9903487b4277443dc3da04615f6c4d3b48564a8ab57/aiobotocore-2.15.2-py3-none-any.whl", hash = "sha256:d4d3128b4b558e2b4c369bfa963b022d7e87303adb82eec623cec8aa77ae578a", size = 77400 }, +] + +[package.optional-dependencies] +boto3 = [ 
+ { name = "boto3" }, +] + [[package]] name = "aiofiles" version = "24.1.0" @@ -69,6 +102,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/8a/b4f3a8d0fb7f4fdb3869db6c3334e23e11878123605579e067be85f7e01f/aiohttp-3.10.8-cp313-cp313-win_amd64.whl", hash = "sha256:98a4eb60e27033dee9593814ca320ee8c199489fbc6b2699d0f710584db7feb7", size = 376618 }, ] +[[package]] +name = "aioitertools" +version = "0.12.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/06/de/38491a84ab323b47c7f86e94d2830e748780525f7a10c8600b67ead7e9ea/aioitertools-0.12.0.tar.gz", hash = "sha256:c2a9055b4fbb7705f561b9d86053e8af5d10cc845d22c32008c43490b2d8dd6b", size = 19369 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/13/58b70a580de00893223d61de8fea167877a3aed97d4a5e1405c9159ef925/aioitertools-0.12.0-py3-none-any.whl", hash = "sha256:fc1f5fac3d737354de8831cbba3eb04f79dd649d8f3afb4c5b114925e662a796", size = 24345 }, +] + [[package]] name = "aiosignal" version = "1.3.1" @@ -376,6 +418,7 @@ name = "fastapi-epitope-prediction" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "aioboto3" }, { name = "aiofiles" }, { name = "boto3" }, { name = "fastapi", extra = ["standard"] }, @@ -400,6 +443,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aioboto3", specifier = ">=13.2.0" }, { name = "aiofiles", specifier = ">=24.1.0" }, { name = "boto3", specifier = ">=1.35.29" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.115.0" }, @@ -1465,6 +1509,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/56/27/96a5cd2626d11c8280656c6c71d8ab50fe006490ef9971ccd154e0c42cd2/websockets-13.1-py3-none-any.whl", hash = "sha256:a9a396a6ad26130cdae92ae10c36af09d9bfe6cafe69670fd3b6da9b07b4044f", size = 152134 }, ] +[[package]] +name = "wrapt" +version = "1.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/4c/063a912e20bcef7124e0df97282a8af3ff3e4b603ce84c481d6d7346be0a/wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d", size = 53972 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/92/17/224132494c1e23521868cdd57cd1e903f3b6a7ba6996b7b8f077ff8ac7fe/wrapt-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b", size = 37614 }, + { url = "https://files.pythonhosted.org/packages/6a/d7/cfcd73e8f4858079ac59d9db1ec5a1349bc486ae8e9ba55698cc1f4a1dff/wrapt-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36", size = 38316 }, + { url = "https://files.pythonhosted.org/packages/7e/79/5ff0a5c54bda5aec75b36453d06be4f83d5cd4932cc84b7cb2b52cee23e2/wrapt-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73", size = 86322 }, + { url = "https://files.pythonhosted.org/packages/c4/81/e799bf5d419f422d8712108837c1d9bf6ebe3cb2a81ad94413449543a923/wrapt-1.16.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809", size = 79055 }, + { url = 
"https://files.pythonhosted.org/packages/62/62/30ca2405de6a20448ee557ab2cd61ab9c5900be7cbd18a2639db595f0b98/wrapt-1.16.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b", size = 87291 }, + { url = "https://files.pythonhosted.org/packages/49/4e/5d2f6d7b57fc9956bf06e944eb00463551f7d52fc73ca35cfc4c2cdb7aed/wrapt-1.16.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81", size = 90374 }, + { url = "https://files.pythonhosted.org/packages/a6/9b/c2c21b44ff5b9bf14a83252a8b973fb84923764ff63db3e6dfc3895cf2e0/wrapt-1.16.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9", size = 83896 }, + { url = "https://files.pythonhosted.org/packages/14/26/93a9fa02c6f257df54d7570dfe8011995138118d11939a4ecd82cb849613/wrapt-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c", size = 91738 }, + { url = "https://files.pythonhosted.org/packages/a2/5b/4660897233eb2c8c4de3dc7cefed114c61bacb3c28327e64150dc44ee2f6/wrapt-1.16.0-cp312-cp312-win32.whl", hash = "sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc", size = 35568 }, + { url = "https://files.pythonhosted.org/packages/5c/cc/8297f9658506b224aa4bd71906447dea6bb0ba629861a758c28f67428b91/wrapt-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8", size = 37653 }, + { url = "https://files.pythonhosted.org/packages/ff/21/abdedb4cdf6ff41ebf01a74087740a709e2edb146490e4d9beea054b0b7a/wrapt-1.16.0-py3-none-any.whl", hash = "sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1", size = 23362 }, +] + [[package]] name = "yarl" version = "1.13.1" diff --git a/notebooks/SaprotHub.ipynb b/notebooks/SaprotHub.ipynb deleted file mode 100644 index 63cd04f..0000000 --- a/notebooks/SaprotHub.ipynb +++ /dev/null @@ -1,3189 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "UUdjG4-XsE0I" - }, - "source": [ - "# **SaprotHub: Making Protein Modeling Accessible to All Biologists**\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "This is **ColabSaprot**, the Colab version of [SaProt](https://github.com/westlake-repl/SaProt), a pre-trained protein language model designed for various downstream protein tasks.\n", - "\n", - "**ColabSaprot** is a platform where **Protein Language Models(PLMs)** are more accessible and user-friendly for biologists, enabling effortless model training and sharing within the scientific community.\n", - "\n", - "We've established the **SaprotHub**([website](https://huggingface.co/SaProtHub), [paper](https://www.biorxiv.org/content/10.1101/2024.05.24.595648v3)) for storing and sharing models and datasets, where you can explore extensive collections for specific protein prediction tasks.\n", - "\n", - "We hope ColabSaprot and SaprotHub can contribute to advancing biological research, fostering collaboration, and accelerating discoveries in the field. 
You can access [our paper](https://www.biorxiv.org/content/10.1101/2024.05.24.595648v3) for further details.\n", - "\n", - "Check these videos ([training](https://www.youtube.com/watch?v=r42z1hvYKfw), [predicting](https://www.youtube.com/watch?v=N5VMBwM_ukQ)) to see how to use ColabSaprot.\n", - "\n", - "Joining [**OPMC**](https://theopmc.github.io/) as an author of SaprotHub.\n", - "\n", - "ColabSaprot supports hundreds of [protein prediction tasks](https://github.com/westlake-repl/SaProtHub/blob/main/task_list.md).\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1nLb_im9sJWw" - }, - "source": [ - "## ColabSaprot\n", - "\n", - "| Function | Tutorial | Video |\n", - "| ------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |\n", - "| Train your model | [How to train your model](https://github.com/westlake-repl/SaprotHub/wiki/2.1:-Train-your-model) | - [YouTube](https://www.youtube.com/watch?v=r42z1hvYKfw)
- [Bilibili](https://www.bilibili.com/video/BV1HDhHeTEmH/?spm_id_from=333.337.search-card.all.click&vd_source=a418185fadee73ac65d8fab69eee0b52) |\n", - "| Classification/Regression Prediction | [How to use model for classification/regression prediction](https://github.com/westlake-repl/SaprotHub/wiki/3.1:-Classification-Regression-Prediction) | - [YouTube](https://www.youtube.com/watch?v=N5VMBwM_ukQ) |\n", - "| Mutational Effect Prediction | [How to use model for mutational effect prediction](https://github.com/westlake-repl/SaprotHub/wiki/3.2:-Mutational-Effect-Prediction) | - |\n", - "| Inverse Folding Prediction | [How to use model for inverse folding prediction](https://github.com/westlake-repl/SaprotHub/wiki/3.3:-Inverse-Folding-Prediction) | - |\n", - "| Contribute to SaprotHub | [How to contribute to SaprotHub](https://github.com/westlake-repl/SaprotHub/wiki/0.4:-Contribute-to-SaprotHub) | - |\n", - "\n", - "
\n", - "\n", - "**To view the content, please click on the first option in the left sidebar.**\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zVQ6vaQTjYO3" - }, - "source": [ - "## SaprotHub\n", - "\n", - "Find awesome models and datasets for specific protein task on SaprotHub!\n", - "\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3-dw1U1uBI7d" - }, - "source": [ - "# **1: Installation**\n", - "\n", - "## ⚠️SWITCH YOUR RUNTIME TYPE TO GPU\n", - "Before installing SaProt, please **SWITCH YOUR RUNTIME TYPE TO GPU!!!**\n", - "\n", - "> 📍Please check this [page](https://github.com/westlake-repl/SaprotHub/wiki/1.1:-Switch-your-runtime-type-to-GPU) to learn **how to switch your runtime type to GPU**.\n", - "\n", - "## ⚠️Maximum Runtime and Idle Timeout\n", - "\n", - "To ensure your program finishes properly, please avoid letting your computer go to **sleep** or remain **idle** for long periods.\n", - "\n", - "Please be aware of **the maximum runtime**, as your program may be automatically terminated when this limit is reached.\n", - "\n", - "| Plan | Maximum Runtime | Idle Timeout | Additional Features |\n", - "|-----------------|------------------|--------------|--------------------------------------------|\n", - "| **Free** | 12 hours | Yes | - |\n", - "| **Colab Pro** | Based on availability and usage patterns | Yes | Increased compute availability |\n", - "| **Pay As You Go**| Based on availability and usage patterns | Yes | Increased compute availability |\n", - "| **Colab Pro+** | Up to 24 hours | No | Background execution, continuous code execution |\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "Tgvb8ibwBI7d" - }, - "outputs": [], - "source": [ - "#@title **1.1: ▶️ Click the run button to install SaProt**\n", - "\n", - "#@markdown (Please waiting for 2-8 minutes to install...)\n", - "################################################################################\n", - "########################### install saprot #####################################\n", - "################################################################################\n", - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import os\n", - "# Check whether the server is local or from google cloud\n", - "root_dir = os.getcwd()\n", - "\n", - "from google.colab import output\n", - "output.enable_custom_widget_manager()\n", - "\n", - "try:\n", - " import sys\n", - " sys.path.append(f\"{root_dir}/SaprotHub\")\n", - " import saprot\n", - " print(\"SaProt is installed successfully!\")\n", - " os.system(f\"chmod +x {root_dir}/SaprotHub/bin/*\")\n", - "\n", - "except ImportError:\n", - " print(\"Installing SaProt...\")\n", - " os.system(f\"rm -rf {root_dir}/SaprotHub\")\n", - " # !rm -rf /content/SaprotHub/\n", - "\n", - " !git clone https://github.com/westlake-repl/SaprotHub.git\n", - "\n", - " # !pip install /content/SaprotHub/saprot-0.4.7-py3-none-any.whl\n", - " os.system(f\"pip install -r {root_dir}/SaprotHub/requirements.txt\")\n", - " # !pip install -r /content/SaprotHub/requirements.txt\n", - "\n", - " os.system(f\"pip install {root_dir}/SaprotHub\")\n", - "\n", - "\n", - " os.system(f\"mkdir -p {root_dir}/SaprotHub/LMDB\")\n", - " os.system(f\"mkdir -p {root_dir}/SaprotHub/bin\")\n", - " os.system(f\"mkdir -p {root_dir}/SaprotHub/output\")\n", - " os.system(f\"mkdir -p {root_dir}/SaprotHub/datasets\")\n", - " os.system(f\"mkdir -p {root_dir}/SaprotHub/adapters/classification/Local\")\n", - " 
os.system(f\"mkdir -p {root_dir}/SaprotHub/adapters/regression/Local\")\n", - " os.system(f\"mkdir -p {root_dir}/SaprotHub/adapters/token_classification/Local\")\n", - " os.system(f\"mkdir -p {root_dir}/SaprotHub/adapters/pair_classification/Local\")\n", - " os.system(f\"mkdir -p {root_dir}/SaprotHub/adapters/pair_regression/Local\")\n", - " os.system(f\"mkdir -p {root_dir}/SaprotHub/structures\")\n", - " # !mkdir -p /content/SaprotHub/LMDB\n", - " # !mkdir -p /content/SaprotHub/bin\n", - " # !mkdir -p /content/SaprotHub/output\n", - " # !mkdir -p /content/SaprotHub/datasets\n", - " # !mkdir -p /content/SaprotHub/adapters/classification/Local\n", - " # !mkdir -p /content/SaprotHub/adapters/regression/Local\n", - " # !mkdir -p /content/SaprotHub/adapters/token_classification/Local\n", - " # !mkdir -p /content/SaprotHub/adapters/pair_classification/Local\n", - " # !mkdir -p /content/SaprotHub/adapters/pair_regression/Local\n", - " # !mkdir -p /content/SaprotHub/structures\n", - "\n", - " # !pip install gdown==v4.6.3 --force-reinstall --quiet\n", - " # os.system(\n", - " # f\"wget 'https://drive.usercontent.google.com/download?id=1B_9t3n_nlj8Y3Kpc_mMjtMdY0OPYa7Re&export=download&authuser=0' -O {root_dir}/SaprotHub/bin/foldseek\"\n", - " # )\n", - "\n", - " os.system(f\"chmod +x {root_dir}/SaprotHub/bin/*\")\n", - " # !chmod +x /content/SaprotHub/bin/foldseek\n", - " import sys\n", - " sys.path.append(f\"{root_dir}/SaprotHub\")\n", - "\n", - " # !mv /content/SaprotHub/ColabSaprotSetup/foldseek /content/SaprotHub/bin/\n", - "\n", - "################################################################################\n", - "################################################################################\n", - "################################## global ######################################\n", - "################################################################################\n", - "################################################################################\n", - "\n", - "import ipywidgets\n", - "import pandas as pd\n", - "import torch\n", - "import numpy as np\n", - "import lmdb\n", - "import base64\n", - "import copy\n", - "import os\n", - "import json\n", - "import zipfile\n", - "import yaml\n", - "import argparse\n", - "import pprint\n", - "import subprocess\n", - "import py3Dmol\n", - "import matplotlib.pyplot as plt\n", - "import shutil\n", - "import torch.nn.functional as F\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n", - "\n", - "from loguru import logger\n", - "from easydict import EasyDict\n", - "from colorama import init, Fore, Back, Style\n", - "from IPython.display import clear_output\n", - "from huggingface_hub import snapshot_download\n", - "from ipywidgets import HTML\n", - "from IPython.display import display\n", - "from google.colab import widgets\n", - "from google.colab import files\n", - "from pathlib import Path\n", - "from tqdm import tqdm\n", - "from datetime import datetime\n", - "from transformers import AutoTokenizer, EsmForProteinFolding, EsmTokenizer\n", - "from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein\n", - "from transformers.models.esm.openfold_utils.feats import atom14_to_atom37\n", - "from string import ascii_uppercase,ascii_lowercase\n", - "from saprot.utils.mpr import MultipleProcessRunnerSimplifier\n", - "from saprot.data.parse import get_chain_ids\n", - "from saprot.scripts.training import my_load_model\n", - "from safetensors import safe_open\n", - "\n", - "DATASET_HOME 
= Path(f'{root_dir}/SaprotHub/datasets')\n", - "ADAPTER_HOME = Path(f'{root_dir}/SaprotHub/adapters')\n", - "STRUCTURE_HOME = Path(f\"{root_dir}/SaprotHub/structures\")\n", - "LMDB_HOME = Path(f'{root_dir}/SaprotHub/LMDB')\n", - "OUTPUT_HOME = Path(f'{root_dir}/SaprotHub/output')\n", - "UPLOAD_FILE_HOME = Path(f'{root_dir}/SaprotHub/upload_files')\n", - "FOLDSEEK_PATH = Path(f\"{root_dir}/SaprotHub/bin/foldseek\")\n", - "aa_set = {\"A\", \"C\", \"D\", \"E\", \"F\", \"G\", \"H\", \"I\", \"K\", \"L\", \"M\", \"N\", \"P\", \"Q\", \"R\", \"S\", \"T\", \"V\", \"W\", \"Y\"}\n", - "foldseek_struc_vocab = \"pynwrqhgdlvtmfsaeikc#\"\n", - "\n", - "data_type_list = [\"Single AA Sequence\",\n", - " \"Single SA Sequence\",\n", - " \"Single UniProt ID\",\n", - " \"Single PDB/CIF Structure\",\n", - " \"Multiple AA Sequences\",\n", - " \"Multiple SA Sequences\",\n", - " \"Multiple UniProt IDs\",\n", - " \"Multiple PDB/CIF Structures\",\n", - " \"SaprotHub Dataset\",\n", - " \"A pair of AA Sequences\",\n", - " \"A pair of SA Sequences\",\n", - " \"A pair of UniProt IDs\",\n", - " \"A pair of PDB/CIF Structures\",\n", - " \"Multiple pairs of AA Sequences\",\n", - " \"Multiple pairs of SA Sequences\",\n", - " \"Multiple pairs of UniProt IDs\",\n", - " \"Multiple pairs of PDB/CIF Structures\",]\n", - "\n", - "data_type_list_single = [\n", - " \"Single AA Sequence\",\n", - " \"Single SA Sequence\",\n", - " \"Single UniProt ID\",\n", - " \"Single PDB/CIF Structure\",\n", - " \"A pair of AA Sequences\",\n", - " \"A pair of SA Sequences\",\n", - " \"A pair of UniProt IDs\",\n", - " \"A pair of PDB/CIF Structures\",]\n", - "\n", - "data_type_list_multiple = [\n", - " \"Multiple AA Sequences\",\n", - " \"Multiple SA Sequences\",\n", - " \"Multiple UniProt IDs\",\n", - " \"Multiple PDB/CIF Structures\",\n", - " \"Multiple pairs of AA Sequences\",\n", - " \"Multiple pairs of SA Sequences\",\n", - " \"Multiple pairs of UniProt IDs\",\n", - " \"Multiple pairs of PDB/CIF Structures\",]\n", - "\n", - "task_type_dict = {\n", - " \"Protein-level Classification\": \"classification\",\n", - " \"Residue-level Classification\" : \"token_classification\",\n", - " \"Protein-level Regression\" : \"regression\",\n", - " \"Protein-protein Classification\": \"pair_classification\",\n", - " \"Protein-protein Regression\": \"pair_regression\",\n", - "}\n", - "model_type_dict = {\n", - " \"classification\" : \"saprot/saprot_classification_model\",\n", - " \"token_classification\" : \"saprot/saprot_token_classification_model\",\n", - " \"regression\" : \"saprot/saprot_regression_model\",\n", - " \"pair_classification\" : \"saprot/saprot_pair_classification_model\",\n", - " \"pair_regression\" : \"saprot/saprot_pair_regression_model\",\n", - "}\n", - "dataset_type_dict = {\n", - " \"classification\": \"saprot/saprot_classification_dataset\",\n", - " \"token_classification\" : \"saprot/saprot_token_classification_dataset\",\n", - " \"regression\": \"saprot/saprot_regression_dataset\",\n", - " \"pair_classification\" : \"saprot/saprot_pair_classification_dataset\",\n", - " \"pair_regression\" : \"saprot/saprot_pair_regression_dataset\",\n", - "}\n", - "training_data_type_dict = {\n", - " \"Single AA Sequence\": \"AA\",\n", - " \"Single SA Sequence\": \"SA\",\n", - " \"Single UniProt ID\": \"SA\",\n", - " \"Single PDB/CIF Structure\": \"SA\",\n", - " \"Multiple AA Sequences\": \"AA\",\n", - " \"Multiple SA Sequences\": \"SA\",\n", - " \"Multiple UniProt IDs\": \"SA\",\n", - " \"Multiple PDB/CIF Structures\": \"SA\",\n", - " \"SaprotHub 
Dataset\": \"SA\",\n", - " \"A pair of AA Sequences\": \"AA\",\n", - " \"A pair of SA Sequences\": \"SA\",\n", - " \"A pair of UniProt IDs\": \"SA\",\n", - " \"A pair of PDB/CIF Structures\": \"SA\",\n", - " \"Multiple pairs of AA Sequences\": \"AA\",\n", - " \"Multiple pairs of SA Sequences\": \"SA\",\n", - " \"Multiple pairs of UniProt IDs\": \"SA\",\n", - " \"Multiple pairs of PDB/CIF Structures\": \"SA\",\n", - "}\n", - "\n", - "\n", - "class font:\n", - " RED = '\\033[91m'\n", - " GREEN = '\\033[92m'\n", - " YELLOW = '\\033[93m'\n", - " BLUE = '\\033[94m'\n", - "\n", - " BOLD = '\\033[1m'\n", - " UNDERLINE = '\\033[4m'\n", - "\n", - " RESET = '\\033[0m'\n", - "\n", - "\n", - "################################################################################\n", - "############################### adapters #######################################\n", - "################################################################################\n", - "def get_adapters_list(task_type=None):\n", - "\n", - " adapters_list = []\n", - "\n", - " if task_type:\n", - " for file_path in (ADAPTER_HOME / task_type).glob('**/adapter_config.json'):\n", - " adapters_list.append(file_path.relative_to(ADAPTER_HOME / task_type).parent)\n", - " else:\n", - " for file_path in ADAPTER_HOME.glob('**/adapter_config.json'):\n", - " adapters_list.append(file_path.relative_to(ADAPTER_HOME).parent)\n", - "\n", - " return adapters_list\n", - "\n", - "def adapters_text(adapters_list):\n", - " input = ipywidgets.Text(\n", - " value=None,\n", - " placeholder='Enter SaprotHub Model ID',\n", - " # description='Selected:',\n", - " disabled=False)\n", - " input.layout.width = '500px'\n", - " display(input)\n", - "\n", - " return input\n", - "\n", - "def adapters_dropdown(adapters_list):\n", - " dropdown = ipywidgets.Dropdown(\n", - " # options=[f\"{adapter_path.parent.stem}/{adapter_path.stem}\" for index, adapter_path in enumerate(adapters_list)],\n", - " options=adapters_list,\n", - " value=None,\n", - " placeholder='Select a Local Model here',\n", - " # description='Selected:',\n", - " disabled=False)\n", - " dropdown.layout.width = '500px'\n", - " display(dropdown)\n", - "\n", - " return dropdown\n", - "\n", - "def adapters_combobox(adapters_list):\n", - " combobox = ipywidgets.Combobox(\n", - " options=[f\"{adapter_path.parent.stem}/{adapter_path.stem}\" for index, adapter_path in enumerate(adapters_list)],\n", - " value=None,\n", - " placeholder='Enter SaprotHub Model repository id or select a Local Model here',\n", - " # description='Selected:',\n", - " disabled=False)\n", - " combobox.layout.width = '500px'\n", - " display(combobox)\n", - "\n", - " return combobox\n", - "\n", - "def adapters_selectmultiple(adapters_list):\n", - " selectmulitiple = ipywidgets.SelectMultiple(\n", - " # options=[f\"{adapter_path.parent.stem}/{adapter_path.stem}\" for index, adapter_path in enumerate(adapters_list)],\n", - " options=adapters_list,\n", - " value=[],\n", - " #rows=10,\n", - " placeholder='Select multiple models',\n", - " # description='Fruits',\n", - " disabled=False,\n", - " layout={'width': '500px'})\n", - " display(selectmulitiple)\n", - "\n", - " return selectmulitiple\n", - "\n", - "def adapters_textmultiple(adapters_list):\n", - " textmultiple = ipywidgets.Text(\n", - " value=None,\n", - " placeholder='Enter multiple SaprotHub Model IDs, separated by commas.',\n", - " # description='Fruits',\n", - " disabled=False,\n", - " layout={'width': '500px'})\n", - " display(textmultiple)\n", - "\n", - " return textmultiple\n", - "\n", 
- "\n", - "def select_adapter_from(task_type, use_model_from):\n", - " adapters_list = get_adapters_list(task_type)\n", - "\n", - " if use_model_from == 'Trained by yourself on ColabSaprot':\n", - " print(Fore.BLUE+f\"Local Model ({task_type}):\"+Style.RESET_ALL)\n", - " return adapters_dropdown(adapters_list)\n", - "\n", - " elif use_model_from == 'Shared by peers on SaprotHub':\n", - " print(Fore.BLUE+\"SaprotHub Model:\"+Style.RESET_ALL)\n", - " return adapters_text(adapters_list)\n", - "\n", - " elif use_model_from == \"Saved in your local computer\":\n", - " print(Fore.BLUE+\"Click the button to upload the \\\"Model--.zip\\\" file of your Model:\"+Style.RESET_ALL)\n", - " # 1. upload model.zip\n", - " if task_type:\n", - " adapter_upload_path = ADAPTER_HOME / task_type / \"Local\"\n", - " else:\n", - " adapter_upload_path = ADAPTER_HOME / \"Local\"\n", - "\n", - " adapter_zip_path = upload_file(adapter_upload_path)\n", - " adapter_path = adapter_upload_path / adapter_zip_path.stem\n", - " # 2. unzip model.zip\n", - " with zipfile.ZipFile(adapter_zip_path, 'r') as zip_ref:\n", - " zip_ref.extractall(adapter_path)\n", - " os.remove(adapter_zip_path)\n", - " # 3. check adapter_config.json\n", - " adapter_config_path = adapter_path / \"adapter_config.json\"\n", - " assert adapter_config_path.exists(), f\"Can't find {adapter_config_path}\"\n", - "\n", - " # # 4. move to correct folder\n", - " # num_labels, task_type = get_num_labels_and_task_type_by_adapter(adapter_path)\n", - " # shutil.move(adapter_path, ADAPTER_HOME / task_type)\n", - "\n", - " return EasyDict({\"value\": f\"Local/{adapter_zip_path.stem}\"})\n", - "\n", - " elif use_model_from == \"Multi-models on ColabSaprot\":\n", - " # 1. select the list of adapters\n", - " print(Fore.BLUE+f\"Local Model ({task_type}):\"+Style.RESET_ALL)\n", - " print(Fore.BLUE+f\"Multiple values can be selected with \\\"shift\\\" and/or \\\"ctrl\\\" (or \\\"command\\\") pressed and mouse clicks or arrow keys.\"+Style.RESET_ALL)\n", - " return adapters_selectmultiple(adapters_list)\n", - "\n", - " elif use_model_from == \"Multi-models on SaprotHub\":\n", - " # 1. 
enter the list of adapters\n", - " print(Fore.BLUE+f\"SaprotHub Model IDs, separated by commas ({task_type}):\"+Style.RESET_ALL)\n", - " return adapters_textmultiple(adapters_list)\n", - "\n", - "\n", - "\n", - "################################################################################\n", - "########################### download dataset ###################################\n", - "################################################################################\n", - "def download_dataset(task_name):\n", - " import gdown\n", - " import tarfile\n", - "\n", - " filepath = LMDB_HOME / f\"{task_name}.tar.gz\"\n", - " download_links = {\n", - " \"ClinVar\" : \"https://drive.google.com/uc?id=1Le6-v8ddXa1eLJZFo7HPij7NhaBmNUbo\",\n", - " \"DeepLoc_cls2\" : \"https://drive.google.com/uc?id=1dGlojkCt1DwUXWiUk4kXRGRNu5sz2uxf\",\n", - " \"DeepLoc_cls10\" : \"https://drive.google.com/uc?id=1dGlojkCt1DwUXWiUk4kXRGRNu5sz2uxf\",\n", - " \"EC\" : \"https://drive.google.com/uc?id=1VFLFA-jK1tkTZBVbMw8YSsjZqAqlVQVQ\",\n", - " \"GO_BP\" : \"https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF\",\n", - " \"GO_CC\" : \"https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF\",\n", - " \"GO_MF\" : \"https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF\",\n", - " \"HumanPPI\" : \"https://drive.google.com/uc?id=1ahgj-IQTtv3Ib5iaiXO_ASh2hskEsvoX\",\n", - " \"MetalIonBinding\" : \"https://drive.google.com/uc?id=1rwknPWIHrXKQoiYvgQy4Jd-efspY16x3\",\n", - " \"ProteinGym\" : \"https://drive.google.com/uc?id=1L-ODrhfeSjDom-kQ2JNDa2nDEpS8EGfD\",\n", - " \"Thermostability\" : \"https://drive.google.com/uc?id=1I9GR1stFDHc8W3FCsiykyrkNprDyUzSz\",\n", - " }\n", - "\n", - " try:\n", - " gdown.download(download_links[task_name], str(filepath), quiet=False)\n", - " with tarfile.open(filepath, 'r:gz') as tar:\n", - " tar.extractall(path=str(LMDB_HOME))\n", - " print(f\"Extracted: {filepath}\")\n", - " except Exception as e:\n", - " raise RuntimeError(\"The dataset has not prepared.\")\n", - "\n", - "################################################################################\n", - "############################# upload file ######################################\n", - "################################################################################\n", - "def upload_file(upload_path):\n", - " upload_path = Path(upload_path)\n", - " upload_path.mkdir(parents=True, exist_ok=True)\n", - " basepath = Path().resolve()\n", - " try:\n", - " uploaded = files.upload()\n", - " filenames = []\n", - " for filename in uploaded.keys():\n", - " filenames.append(filename)\n", - " shutil.move(basepath / filename, upload_path / filename)\n", - " if len(filenames) == 0:\n", - " logger.info(\"The uploading process has been interrupted by the user.\")\n", - " raise RuntimeError(\"The uploading process has been interrupted by the user.\")\n", - " except Exception as e:\n", - " logger.error(\"Upload file fail! 
Please click the button to run again.\")\n", - " raise(e)\n", - "\n", - " return upload_path / filenames[0]\n", - "\n", - "################################################################################\n", - "############################ upload dataset ####################################\n", - "################################################################################\n", - "\n", - "def read_csv_dataset(uploaded_csv_path):\n", - " df = pd.read_csv(uploaded_csv_path)\n", - " df.columns = df.columns.str.lower()\n", - " return df\n", - "\n", - "def check_column_label_and_stage(csv_dataset_path):\n", - " df = read_csv_dataset(csv_dataset_path)\n", - " assert {'label', 'stage'}.issubset(df.columns), f\"Make sure your CSV dataset includes both `label` and `stage` columns!\\nCurrent columns: {df.columns}\"\n", - " column_values = set(df['stage'].unique())\n", - " assert all(value in column_values for value in ['train', 'valid', 'test']), f\"Ensure your dataset includes samples for all three stages: `train`, `valid` and `test`.\\nCurrent columns: {df.columns}\"\n", - "\n", - "def get_data_type(csv_dataset_path):\n", - " # AA, SA, Pair AA, Pair SA\n", - " df = read_csv_dataset(csv_dataset_path)\n", - "\n", - " # AA, SA\n", - " if 'sequence' in df.columns:\n", - " second_token = df.loc[0, 'sequence'][1]\n", - " if second_token in aa_set:\n", - " return \"Multiple AA Sequences\"\n", - " elif second_token in foldseek_struc_vocab:\n", - " return \"Multiple SA Sequences\"\n", - " else:\n", - " raise RuntimeError(f\"The sequence in the dataset({csv_dataset_path}) are neither SA Sequences nor AA Sequences. Please check carefully.\")\n", - "\n", - " # Pair AA, Pair SA\n", - " elif 'sequence_1' in df.columns and 'sequence_2' in df.columns:\n", - " second_token = df.loc[0, 'sequence_1'][1]\n", - " if second_token in aa_set:\n", - " return \"Multiple pairs of AA Sequences\"\n", - " elif second_token in foldseek_struc_vocab:\n", - " return \"Multiple pairs of SA Sequences\"\n", - " else:\n", - " raise RuntimeError(f\"The sequence in the dataset({csv_dataset_path}) are neither SA Sequences nor AA Sequences. Please check carefully.\")\n", - "\n", - " else:\n", - " raise RuntimeError(f\"The data type of the dataset({csv_dataset_path}) should be one of the following types: Multiple AA Sequences, Multiple SA Sequences, Multiple pairs of AA Sequences, Multiple pairs of SA Sequences\")\n", - "\n", - "def check_task_type_and_data_type(original_task_type, data_type):\n", - " if \"Protein-protein\" in original_task_type:\n", - " assert data_type == \"SaprotHub Dataset\" or \"pair\" in data_type, f\"The current `data_type`({data_type}) is incompatible with the current `task_type`({original_task_type}). Please use Pair Sequence Datset for {original_task_type} task!\"\n", - " else:\n", - " assert \"pair\" not in data_type, f\"The current `data_type`({data_type}) is incompatible with the current `task_type`({original_task_type}). Please avoid using the Pair Sequence Dataset({data_type}) for the {original_task_type} task!\"\n", - "\n", - "def input_raw_data_by_data_type(data_type):\n", - " print(Fore.BLUE+\"Dataset: \"+Style.RESET_ALL, end='')\n", - "\n", - " # 0-2. 0. Single AA Sequence, 1. Single SA Sequence, 2. 
Single UniProt ID\n", - " if data_type in data_type_list[:3]:\n", - " input_seq = ipywidgets.Text(\n", - " value=None,\n", - " placeholder=f'Enter {data_type} here',\n", - " disabled=False)\n", - " input_seq.layout.width = '500px'\n", - " print(Fore.BLUE+f\"{data_type}\"+Style.RESET_ALL)\n", - " display(input_seq)\n", - " return input_seq\n", - "\n", - " # 3. Single PDB/CIF Structure\n", - " elif data_type == 'Single PDB/CIF Structure':\n", - " print(\"Please provide the structure type, chain and your structure file.\")\n", - "\n", - " dropdown_type = ipywidgets.Dropdown(\n", - " value=\"AF2\",\n", - " options=[\"PDB\", \"AF2\"],\n", - " disabled=False)\n", - " dropdown_type.layout.width = '500px'\n", - " print(Fore.BLUE+\"Structure type:\"+Style.RESET_ALL)\n", - " display(dropdown_type)\n", - "\n", - " input_chain = ipywidgets.Text(\n", - " value=\"A\",\n", - " placeholder=f'Enter the name of chain here',\n", - " disabled=False)\n", - " input_chain.layout.width = '500px'\n", - " print(Fore.BLUE+\"Chain:\"+Style.RESET_ALL)\n", - " display(input_chain)\n", - "\n", - " print(Fore.BLUE+\"Please upload a .pdb/.cif file\"+Style.RESET_ALL)\n", - " pdb_file_path = upload_file(STRUCTURE_HOME)\n", - " return pdb_file_path.stem, dropdown_type, input_chain\n", - "\n", - " # 4-7 & 13-16. Multiple Sequences\n", - " elif data_type in data_type_list_multiple:\n", - " print(Fore.BLUE+f\"Please upload the .csv file which contains {data_type}\"+Style.RESET_ALL)\n", - " uploaded_csv_path = upload_file(UPLOAD_FILE_HOME)\n", - " print(Fore.BLUE+\"Successfully upload your .csv file!\"+Style.RESET_ALL)\n", - " print(\"=\"*100)\n", - "\n", - " if data_type in ['Multiple PDB/CIF Structures', 'Multiple pairs of PDB/CIF Structures']:\n", - " # upload and unzip PDB files\n", - " print(Fore.BLUE+f\"Please upload your .zip file which contains {data_type} files\"+Style.RESET_ALL)\n", - " pdb_zip_path = upload_file(UPLOAD_FILE_HOME)\n", - " if pdb_zip_path.suffix != \".zip\":\n", - " logger.error(\"The data type does not match. Please click the run button again to upload a .zip file!\")\n", - " raise RuntimeError(\"The data type does not match.\")\n", - " print(Fore.BLUE+\"Successfully upload your .zip file!\"+Style.RESET_ALL)\n", - " print(\"=\"*100)\n", - "\n", - " import zipfile\n", - " with zipfile.ZipFile(pdb_zip_path, 'r') as zip_ref:\n", - " zip_ref.extractall(STRUCTURE_HOME)\n", - "\n", - " return uploaded_csv_path\n", - "\n", - " # 8. SaprotHub Dataset\n", - " elif data_type == \"SaprotHub Dataset\":\n", - " input_repo_id = ipywidgets.Text(\n", - " value=None,\n", - " placeholder=f'Copy and paste the SaprotHub Dataset ID here',\n", - " disabled=False)\n", - " input_repo_id.layout.width = '500px'\n", - " print(Fore.BLUE+f\"{data_type}\"+Style.RESET_ALL)\n", - " display(input_repo_id)\n", - " return input_repo_id\n", - "\n", - " # 9-11. 
A pair of seq\n", - " elif data_type in [\"A pair of AA Sequences\", \"A pair of SA Sequences\", \"A pair of UniProt IDs\"]:\n", - " print()\n", - "\n", - " seq_type = data_type[len(\"A pair of \"):-1]\n", - "\n", - " input_seq1 = ipywidgets.Text(\n", - " value=None,\n", - " placeholder=f'Enter the {seq_type} of Sequence 1 here',\n", - " disabled=False)\n", - " input_seq1.layout.width = '500px'\n", - " print(Fore.BLUE+f\"Sequence 1:\"+Style.RESET_ALL)\n", - " display(input_seq1)\n", - "\n", - " input_seq2 = ipywidgets.Text(\n", - " value=None,\n", - " placeholder=f'Enter the {seq_type} of Sequence 2 here',\n", - " disabled=False)\n", - " input_seq2.layout.width = '500px'\n", - " print(Fore.BLUE+f\"Sequence 2:\"+Style.RESET_ALL)\n", - " display(input_seq2)\n", - "\n", - " return (input_seq1, input_seq2)\n", - "\n", - " # 12. Pair Single PDB/CIF Structure\n", - " elif data_type == 'A pair of PDB/CIF Structures':\n", - " print(\"Please provide the structure type, chain and your structure file.\")\n", - "\n", - " dropdown_type1 = ipywidgets.Dropdown(\n", - " value=\"PDB\",\n", - " options=[\"PDB\", \"AF2\"],\n", - " disabled=False)\n", - " dropdown_type1.layout.width = '500px'\n", - " print(Fore.BLUE+\"The first structure type:\"+Style.RESET_ALL)\n", - " display(dropdown_type1)\n", - "\n", - " input_chain1 = ipywidgets.Text(\n", - " value=\"A\",\n", - " placeholder=f'Enter the name of chain of the first structure here',\n", - " disabled=False)\n", - " input_chain1.layout.width = '500px'\n", - " print(Fore.BLUE+\"Chain of the first structure:\"+Style.RESET_ALL)\n", - " display(input_chain1)\n", - "\n", - " print(Fore.BLUE+\"Please upload a .pdb/.cif file\"+Style.RESET_ALL)\n", - " pdb_file_path1 = upload_file(STRUCTURE_HOME)\n", - "\n", - "\n", - " dropdown_type2 = ipywidgets.Dropdown(\n", - " value=\"PDB\",\n", - " options=[\"PDB\", \"AF2\"],\n", - " disabled=False)\n", - " dropdown_type2.layout.width = '500px'\n", - " print(Fore.BLUE+\"The second structure type:\"+Style.RESET_ALL)\n", - " display(dropdown_type2)\n", - "\n", - " input_chain2 = ipywidgets.Text(\n", - " value=\"A\",\n", - " placeholder=f'Enter the name of chain of the second structure here',\n", - " disabled=False)\n", - " input_chain2.layout.width = '500px'\n", - " print(Fore.BLUE+\"Chain of the second structure:\"+Style.RESET_ALL)\n", - " display(input_chain2)\n", - "\n", - " print(Fore.BLUE+\"Please upload a .pdb/.cif file\"+Style.RESET_ALL)\n", - " pdb_file_path2 = upload_file(STRUCTURE_HOME)\n", - " return (pdb_file_path1.stem, dropdown_type1, input_chain1, pdb_file_path2.stem, dropdown_type2, input_chain2)\n", - "\n", - "def get_SA_sequence_by_data_type(data_type, raw_data):\n", - "\n", - " # Multiple sequences\n", - " # raw_data = upload_files/xxx.csv\n", - "\n", - " # 8. 
SaprotHub Dataset\n", - " if data_type == \"SaprotHub Dataset\":\n", - " input_repo_id = raw_data\n", - " REPO_ID = input_repo_id.value\n", - "\n", - " if REPO_ID.startswith('/'):\n", - " return Path(REPO_ID)\n", - "\n", - " snapshot_download(repo_id=REPO_ID, repo_type=\"dataset\", local_dir=DATASET_HOME / REPO_ID)\n", - " csv_dataset_path = DATASET_HOME / REPO_ID / 'dataset.csv'\n", - " assert csv_dataset_path.exists(), f\"Can't find {csv_dataset_path}\"\n", - " protein_df = read_csv_dataset(csv_dataset_path)\n", - "\n", - " data_type = get_data_type(csv_dataset_path)\n", - "\n", - " return get_SA_sequence_by_data_type(data_type, csv_dataset_path)\n", - "\n", - " # # AA, SA\n", - " # if data_type == \"Multiple AA Sequences\":\n", - " # for index, value in protein_df['sequence'].items():\n", - " # sa_seq = ''\n", - " # for aa in value:\n", - " # sa_seq += aa + '#'\n", - " # protein_df.at[index, 'sequence'] = sa_seq\n", - "\n", - " # # Pair AA, Pair SA\n", - " # elif data_type in [\"Multiple pairs of AA Sequences\", \"Multiple pairs of SA Sequences\"]:\n", - " # for i in ['1', '2']:\n", - " # if data_type == \"Multiple pairs of AA Sequences\":\n", - " # for index, value in protein_df[f'sequence_{i}'].items():\n", - " # sa_seq = ''\n", - " # for aa in value:\n", - " # sa_seq += aa + '#'\n", - " # protein_df.at[index, f'sequence_{i}'] = sa_seq\n", - "\n", - " # protein_df[f'name_{i}'] = f'name_{i}'\n", - " # protein_df[f'chain_{i}'] = 'A'\n", - "\n", - " # protein_df.to_csv(csv_dataset_path, index=None)\n", - "\n", - " # return csv_dataset_path\n", - "\n", - " elif data_type in data_type_list_multiple:\n", - " uploaded_csv_path = raw_data\n", - " csv_dataset_path = DATASET_HOME / uploaded_csv_path.name\n", - " protein_df = read_csv_dataset(uploaded_csv_path)\n", - "\n", - " if 'pair' in data_type:\n", - " assert {'sequence_1', 'sequence_2'}.issubset(protein_df.columns), f\"The CSV dataset ({uploaded_csv_path}) must contain `sequence_1` and `sequence_2` columns. \\n Current columns:{protein_df.columns}\"\n", - " else:\n", - " assert 'sequence' in protein_df.columns, f\"The CSV Dataset({uploaded_csv_path}) must contain a `sequence` column. \\n Current columns:{protein_df.columns}\"\n", - "\n", - " # 4. Multiple AA Sequences\n", - " if data_type == 'Multiple AA Sequences':\n", - " for index, value in protein_df['sequence'].items():\n", - " sa_seq = ''\n", - " for aa in value:\n", - " sa_seq += aa + '#'\n", - " protein_df.at[index, 'sequence'] = sa_seq\n", - "\n", - " protein_df.to_csv(csv_dataset_path, index=None)\n", - " return csv_dataset_path\n", - "\n", - " # 5. Multiple SA Sequences\n", - " elif data_type == 'Multiple SA Sequences':\n", - " protein_df.to_csv(csv_dataset_path, index=None)\n", - " return csv_dataset_path\n", - "\n", - " # 6. Multiple UniProt IDs\n", - " elif data_type == 'Multiple UniProt IDs':\n", - " protein_list = protein_df.loc[:, 'sequence'].tolist()\n", - " uniprot2pdb(protein_list)\n", - " protein_list = [(uniprot_id, \"AF2\", \"A\") for uniprot_id in protein_list]\n", - " mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)\n", - " outputs = mprs.run()\n", - "\n", - " protein_df['sequence'] = [output.split(\"\\t\")[1] for output in outputs]\n", - " protein_df.to_csv(csv_dataset_path, index=None)\n", - " return csv_dataset_path\n", - "\n", - " # 7. 
Multiple PDB/CIF Structures\n", - " elif data_type == 'Multiple PDB/CIF Structures':\n", - " # protein_list = [(uniprot_id, type, chain), ...]\n", - " # protein_list = [item.split('.')[0] for item in protein_df.iloc[:, 0].tolist()]\n", - " # uniprot2pdb(protein_list)\n", - " protein_list = []\n", - " for row_tuple in protein_df.itertuples(index=False):\n", - " assert row_tuple.type in ['PDB', 'AF2'], \"The type of structure must be either \\\"PDB\\\" or \\\"AF2\\\"!\"\n", - " protein_list.append(row_tuple)\n", - " mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)\n", - " outputs = mprs.run()\n", - "\n", - " protein_df['sequence'] = [output.split(\"\\t\")[1] for output in outputs]\n", - " protein_df.to_csv(csv_dataset_path, index=None)\n", - " return csv_dataset_path\n", - "\n", - " # 13. Pair Multiple AA Sequences\n", - " elif data_type == \"Multiple pairs of AA Sequences\":\n", - " for i in ['1', '2']:\n", - " for index, value in protein_df[f'sequence_{i}'].items():\n", - " sa_seq = ''\n", - " for aa in value:\n", - " sa_seq += aa + '#'\n", - " protein_df.at[index, f'sequence_{i}'] = sa_seq\n", - "\n", - " protein_df[f'name_{i}'] = f'name_{i}'\n", - " protein_df[f'chain_{i}'] = 'A'\n", - "\n", - " protein_df.to_csv(csv_dataset_path, index=None)\n", - " return csv_dataset_path\n", - "\n", - " # 14. Pair Multiple SA Sequences\n", - " elif data_type == \"Multiple pairs of SA Sequences\":\n", - " for i in ['1', '2']:\n", - " protein_df[f'name_{i}'] = f'name_{i}'\n", - " protein_df[f'chain_{i}'] = 'A'\n", - "\n", - " protein_df.to_csv(csv_dataset_path, index=None)\n", - " return csv_dataset_path\n", - "\n", - " # 15. Pair Multiple UniProt IDs\n", - " elif data_type == \"Multiple pairs of UniProt IDs\":\n", - " for i in ['1', '2']:\n", - " protein_list = protein_df.loc[:, f'sequence_{i}'].tolist()\n", - " uniprot2pdb(protein_list)\n", - " protein_df[f'name_{i}'] = protein_list\n", - " protein_list = [(uniprot_id, \"AF2\", \"A\") for uniprot_id in protein_list]\n", - " mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)\n", - " outputs = mprs.run()\n", - "\n", - " protein_df[f'sequence_{i}'] = [output.split(\"\\t\")[1] for output in outputs]\n", - " protein_df[f'chain_{i}'] = 'A'\n", - "\n", - " protein_df.to_csv(csv_dataset_path, index=None)\n", - " return csv_dataset_path\n", - "\n", - " elif data_type == \"Multiple pairs of PDB/CIF Structures\":\n", - " # columns: sequence_1, sequence_2, type_1, type_2, chain_1, chain_2, label, stage\n", - "\n", - " # protein_list = [(uniprot_id, type, chain), ...]\n", - " # protein_list = [item.split('.')[0] for item in protein_df.iloc[:, 0].tolist()]\n", - " # uniprot2pdb(protein_list)\n", - "\n", - " for i in ['1', '2']:\n", - " protein_list = []\n", - " for index, row in protein_df.iterrows():\n", - " assert row[f\"type_{i}\"] in ['PDB', 'AF2'], \"The type of structure must be either \\\"PDB\\\" or \\\"AF2\\\"!\"\n", - " row_tuple = (row[f\"sequence_{i}\"], row[f\"type_{i}\"], row[f\"chain_{i}\"])\n", - " protein_list.append(row_tuple)\n", - " mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)\n", - " outputs = mprs.run()\n", - "\n", - " # add name column, del type column\n", - " protein_df[f'name_{i}'] = protein_df[f'sequence_{i}'].apply(lambda x: x.split('.')[0])\n", - " protein_df.drop(f\"type_{i}\", axis=1, inplace=True)\n", - " protein_df[f'sequence_{i}'] = [output.split(\"\\t\")[1] for output in 
outputs]\n", - "\n", - " # columns: name_1, name_2, chain_1, chain_2, sequence_1, sequence_2, label, stage\n", - " protein_df.to_csv(csv_dataset_path, index=None)\n", - " return csv_dataset_path\n", - "\n", - " else:\n", - " # 0. Single AA Sequence\n", - " if data_type == 'Single AA Sequence':\n", - " input_seq = raw_data\n", - " aa_seq = input_seq.value\n", - "\n", - " sa_seq = ''\n", - " for aa in aa_seq:\n", - " sa_seq += aa + '#'\n", - " return sa_seq\n", - "\n", - " # 1. Single SA Sequence\n", - " elif data_type == 'Single SA Sequence':\n", - " input_seq = raw_data\n", - " sa_seq = input_seq.value\n", - "\n", - " return sa_seq\n", - "\n", - " # 2. Single UniProt ID\n", - " elif data_type == 'Single UniProt ID':\n", - " input_seq = raw_data\n", - " uniprot_id = input_seq.value\n", - "\n", - "\n", - " protein_list = [(uniprot_id, \"AF2\", \"A\")]\n", - " uniprot2pdb([protein_list[0][0]])\n", - " mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)\n", - " seqs = mprs.run()\n", - " sa_seq = seqs[0].split('\\t')[1]\n", - " return sa_seq\n", - "\n", - " # 3. Single PDB/CIF Structure\n", - " elif data_type == 'Single PDB/CIF Structure':\n", - " uniprot_id = raw_data[0]\n", - " struc_type = raw_data[1].value\n", - " chain = raw_data[2].value\n", - "\n", - " protein_list = [(uniprot_id, struc_type, chain)]\n", - " mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)\n", - " seqs = mprs.run()\n", - " assert len(seqs)>0, \"Unable to convert to SA sequence. Please check the `type`, `chain`, and `.pdb/.cif file`.\"\n", - " sa_seq = seqs[0].split('\\t')[1]\n", - " return sa_seq\n", - "\n", - " # 9. Pair Single AA Sequences\n", - " elif data_type == \"A pair of AA Sequences\":\n", - " input_seq_1, input_seq_2 = raw_data\n", - " sa_seq1 = get_SA_sequence_by_data_type('Single AA Sequence', input_seq_1)\n", - " sa_seq2 = get_SA_sequence_by_data_type('Single AA Sequence', input_seq_2)\n", - "\n", - " return (sa_seq1, sa_seq2)\n", - "\n", - " # 10. Pair Single SA Sequences\n", - " elif data_type == \"A pair of SA Sequences\":\n", - " input_seq_1, input_seq_2 = raw_data\n", - " sa_seq1 = get_SA_sequence_by_data_type('Single SA Sequence', input_seq_1)\n", - " sa_seq2 = get_SA_sequence_by_data_type('Single SA Sequence', input_seq_2)\n", - "\n", - " return (sa_seq1, sa_seq2)\n", - "\n", - " # 11. Pair Single UniProt IDs\n", - " elif data_type == \"A pair of UniProt IDs\":\n", - " input_seq_1, input_seq_2 = raw_data\n", - " sa_seq1 = get_SA_sequence_by_data_type('Single UniProt ID', input_seq_1)\n", - " sa_seq2 = get_SA_sequence_by_data_type('Single UniProt ID', input_seq_2)\n", - "\n", - " return (sa_seq1, sa_seq2)\n", - "\n", - " # 12. 
Pair Single PDB/CIF Structure\n", - " elif data_type == \"A pair of PDB/CIF Structures\":\n", - " uniprot_id1 = raw_data[0]\n", - " struc_type1 = raw_data[1].value\n", - " chain1 = raw_data[2].value\n", - "\n", - " protein_list1 = [(uniprot_id1, struc_type1, chain1)]\n", - " mprs1 = MultipleProcessRunnerSimplifier(protein_list1, pdb2sequence, n_process=2, return_results=True)\n", - " seqs1 = mprs1.run()\n", - " sa_seq1 = seqs1[0].split('\\t')[1]\n", - "\n", - " uniprot_id2 = raw_data[3]\n", - " struc_type2 = raw_data[4].value\n", - " chain2 = raw_data[5].value\n", - "\n", - " protein_list2 = [(uniprot_id2, struc_type2, chain2)]\n", - " mprs2 = MultipleProcessRunnerSimplifier(protein_list2, pdb2sequence, n_process=2, return_results=True)\n", - " seqs2 = mprs2.run()\n", - " sa_seq2 = seqs2[0].split('\\t')[1]\n", - " return sa_seq1, sa_seq2\n", - "\n", - "\n", - "\n", - "\n", - "################################################################################\n", - "########################## Download predicted structures #######################\n", - "################################################################################\n", - "def uniprot2pdb(uniprot_ids, nprocess=20):\n", - " from saprot.utils.downloader import AlphaDBDownloader\n", - "\n", - " os.makedirs(STRUCTURE_HOME, exist_ok=True)\n", - " af2_downloader = AlphaDBDownloader(uniprot_ids, \"pdb\", save_dir=STRUCTURE_HOME, n_process=20)\n", - " af2_downloader.run()\n", - "\n", - "\n", - "\n", - "################################################################################\n", - "############### Form foldseek sequences by multiple processes ##################\n", - "################################################################################\n", - "# def pdb2sequence(process_id, idx, uniprot_id, writer):\n", - "# from saprot.utils.foldseek_util import get_struc_seq\n", - "\n", - "# try:\n", - "# pdb_path = f\"{STRUCTURE_HOME}/{uniprot_id}.pdb\"\n", - "# cif_path = f\"{STRUCTURE_HOME}/{uniprot_id}.cif\"\n", - "# if Path(pdb_path).exists():\n", - "# seq = get_struc_seq(FOLDSEEK_PATH, pdb_path, [\"A\"], process_id=process_id)[\"A\"][-1]\n", - "# if Path(cif_path).exists():\n", - "# seq = get_struc_seq(FOLDSEEK_PATH, cif_path, [\"A\"], process_id=process_id)[\"A\"][-1]\n", - "\n", - "# writer.write(f\"{uniprot_id}\\t{seq}\\n\")\n", - "# except Exception as e:\n", - "# print(f\"Error: {uniprot_id}, {e}\")\n", - "\n", - "# clear_output(wait=True)\n", - "# print(\"Installation finished!\")\n", - "\n", - "def pdb2sequence(process_id, idx, row_tuple, writer):\n", - "\n", - " # print(\"=\"*100)\n", - " # print(row_tuple)\n", - " # print(\"=\"*100)\n", - " uniprot_id = row_tuple[0].split('.')[0] #\n", - " struc_type = row_tuple[1] # PDB or AF2\n", - " chain = row_tuple[2]\n", - "\n", - " if struc_type==\"AF2\":\n", - " plddt_mask= True\n", - " chain = 'A'\n", - " else:\n", - " plddt_mask= False\n", - "\n", - " from saprot.utils.foldseek_util import get_struc_seq\n", - "\n", - " try:\n", - " pdb_path = f\"{STRUCTURE_HOME}/{uniprot_id}.pdb\"\n", - " cif_path = f\"{STRUCTURE_HOME}/{uniprot_id}.cif\"\n", - " if Path(pdb_path).exists():\n", - " seq = get_struc_seq(FOLDSEEK_PATH, pdb_path, [chain], process_id=process_id, plddt_mask=plddt_mask)[chain][-1]\n", - " elif Path(cif_path).exists():\n", - " seq = get_struc_seq(FOLDSEEK_PATH, cif_path, [chain], process_id=process_id, plddt_mask=plddt_mask)[chain][-1]\n", - " else:\n", - " raise BaseException(f\"The {uniprot_id}.pdb/{uniprot_id}.cif file doesn't exists!\")\n", - " 
writer.write(f\"{uniprot_id}\\t{seq}\\n\")\n", - "\n", - " except Exception as e:\n", - " print(f\"Error: {uniprot_id}, {e}\")\n", - "\n", - "\n", - "pymol_color_list = [\"#33ff33\",\"#00ffff\",\"#ff33cc\",\"#ffff00\",\"#ff9999\",\"#e5e5e5\",\"#7f7fff\",\"#ff7f00\",\n", - " \"#7fff7f\",\"#199999\",\"#ff007f\",\"#ffdd5e\",\"#8c3f99\",\"#b2b2b2\",\"#007fff\",\"#c4b200\",\n", - " \"#8cb266\",\"#00bfbf\",\"#b27f7f\",\"#fcd1a5\",\"#ff7f7f\",\"#ffbfdd\",\"#7fffff\",\"#ffff7f\",\n", - " \"#00ff7f\",\"#337fcc\",\"#d8337f\",\"#bfff3f\",\"#ff7fff\",\"#d8d8ff\",\"#3fffbf\",\"#b78c4c\",\n", - " \"#339933\",\"#66b2b2\",\"#ba8c84\",\"#84bf00\",\"#b24c66\",\"#7f7f7f\",\"#3f3fa5\",\"#a5512b\"]\n", - "\n", - "alphabet_list = list(ascii_uppercase+ascii_lowercase)\n", - "\n", - "\n", - "def convert_outputs_to_pdb(outputs):\n", - "\tfinal_atom_positions = atom14_to_atom37(outputs[\"positions\"][-1], outputs)\n", - "\toutputs = {k: v.to(\"cpu\").numpy() for k, v in outputs.items()}\n", - "\tfinal_atom_positions = final_atom_positions.cpu().numpy()\n", - "\tfinal_atom_mask = outputs[\"atom37_atom_exists\"]\n", - "\tpdbs = []\n", - "\toutputs[\"plddt\"] *= 100\n", - "\n", - "\tfor i in range(outputs[\"aatype\"].shape[0]):\n", - "\t\taa = outputs[\"aatype\"][i]\n", - "\t\tpred_pos = final_atom_positions[i]\n", - "\t\tmask = final_atom_mask[i]\n", - "\t\tresid = outputs[\"residue_index\"][i] + 1\n", - "\t\tpred = OFProtein(\n", - "\t\t aatype=aa,\n", - "\t\t atom_positions=pred_pos,\n", - "\t\t atom_mask=mask,\n", - "\t\t residue_index=resid,\n", - "\t\t b_factors=outputs[\"plddt\"][i],\n", - "\t\t chain_index=outputs[\"chain_index\"][i] if \"chain_index\" in outputs else None,\n", - "\t\t)\n", - "\t\tpdbs.append(to_pdb(pred))\n", - "\treturn pdbs\n", - "\n", - "\n", - "# This function is copied from ColabFold!\n", - "def show_pdb(path, show_sidechains=False, show_mainchains=False, color=\"lddt\"):\n", - " file_type = str(path).split(\".\")[-1]\n", - " if file_type == \"cif\":\n", - " file_type == \"mmcif\"\n", - "\n", - " view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)\n", - " view.addModel(open(path,'r').read(),file_type)\n", - "\n", - " if color == \"lDDT\":\n", - " view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})\n", - " elif color == \"rainbow\":\n", - " view.setStyle({'cartoon': {'color':'spectrum'}})\n", - " elif color == \"chain\":\n", - " chains = len(get_chain_ids(path))\n", - " for n,chain,color in zip(range(chains),alphabet_list,pymol_color_list):\n", - " view.setStyle({'chain':chain},{'cartoon': {'color':color}})\n", - "\n", - " if show_sidechains:\n", - " BB = ['C','O','N']\n", - " view.addStyle({'and':[{'resn':[\"GLY\",\"PRO\"],'invert':True},{'atom':BB,'invert':True}]},\n", - " {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", - " view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n", - " {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", - " view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n", - " {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", - " if show_mainchains:\n", - " BB = ['C','O','N','CA']\n", - " view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", - "\n", - " view.zoomTo()\n", - " return view\n", - "\n", - "\n", - "def plot_plddt_legend(dpi=100):\n", - " thresh = ['plDDT:','Very low (<50)','Low (60)','OK (70)','Confident (80)','Very high (>90)']\n", - " plt.figure(figsize=(1,0.1),dpi=dpi)\n", - " 
########################################\n", - " for c in [\"#FFFFFF\",\"#FF0000\",\"#FFFF00\",\"#00FF00\",\"#00FFFF\",\"#0000FF\"]:\n", - " plt.bar(0, 0, color=c)\n", - " plt.legend(thresh, frameon=False,\n", - " loc='center', ncol=6,\n", - " handletextpad=1,\n", - " columnspacing=1,\n", - " markerscale=0.5,)\n", - " plt.axis(False)\n", - " return plt\n", - "\n", - "\n", - "################################################################################\n", - "############### Download file to local computer ##################\n", - "################################################################################\n", - "def file_download(path: str):\n", - " with open(path, \"rb\") as r:\n", - " res = r.read()\n", - "\n", - " #FILE\n", - " filename = os.path.basename(path)\n", - " b64 = base64.b64encode(res)\n", - " payload = b64.decode()\n", - "\n", - " #BUTTONS\n", - " html_buttons = '''\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " '''\n", - "\n", - " html_button = html_buttons.format(payload=payload,filename=filename)\n", - " display(HTML(html_button))\n", - "\n", - " # Automatically download file if the server is from google cloud.\n", - " if root_dir == \"/content\":\n", - " files.download(path)\n", - "\n", - "################################################################################\n", - "############################ MODEL INFO #######################################\n", - "################################################################################\n", - "def get_base_model(adapter_path):\n", - " adapter_config = Path(adapter_path) / \"adapter_config.json\"\n", - " with open(adapter_config, 'r') as f:\n", - " adapter_config_dict = json.load(f)\n", - " base_model = adapter_config_dict['base_model_name_or_path']\n", - " if 'SaProt_650M_AF2' in base_model:\n", - " base_model = \"westlake-repl/SaProt_650M_AF2\"\n", - " elif 'SaProt_35M_AF2' in base_model:\n", - " base_model = \"westlake-repl/SaProt_35M_AF2\"\n", - " else:\n", - " raise RuntimeError(\"Please ensure the base model is \\\"SaProt_650M_AF2\\\" or \\\"SaProt_35M_AF2\\\"\")\n", - " return base_model\n", - "\n", - "def check_training_data_type(adapter_path, data_type):\n", - " metadata_path = Path(adapter_path) / \"metadata.json\"\n", - " if metadata_path.exists():\n", - " with open(metadata_path, 'r') as f:\n", - " metadata = json.load(f)\n", - " required_training_data_type = metadata['training_data_type']\n", - " else:\n", - " required_training_data_type = \"SA\"\n", - "\n", - " if (required_training_data_type == \"AA\") and (\"AA\" not in data_type):\n", - " print(Fore.RED+f\"This model ({adapter_path}) is trained on {required_training_data_type} sequences, and predictions work better with AA sequences.\"+Style.RESET_ALL)\n", - " print(Fore.RED+f\"The current data type ({data_type}) includes structural information, which will not be used for predictions.\"+Style.RESET_ALL)\n", - " print()\n", - " print('='*100)\n", - " elif (required_training_data_type == \"SA\") and (\"AA\" in data_type):\n", - " print(Fore.RED+f\"This model ({adapter_path}) is trained on {required_training_data_type} sequences, and predictions work better with SA sequences.\"+Style.RESET_ALL)\n", - " print(Fore.RED+f\"The current data type ({data_type}) does not include structural information, which may lead to weak prediction performance.\"+Style.RESET_ALL)\n", - " print(Fore.RED+f\"If you only have the amino acid sequence, we strongly recommend using AF2 to predict the structure and generate a PDB 
file before prediction.\"+Style.RESET_ALL)\n", - " print()\n", - " print('='*100)\n", - "\n", - " return required_training_data_type\n", - "\n", - "def mask_struc_token(sequence):\n", - " return ''.join('#' if i % 2 == 1 and char.islower() else char for i, char in enumerate(sequence))\n", - "\n", - "def get_num_labels_by_adapter(adapter_path):\n", - " adapter_path = Path(adapter_path)\n", - "\n", - " if (adapter_path / 'adapter_model.safetensors').exists():\n", - " file_path = adapter_path / 'adapter_model.safetensors'\n", - " with safe_open(file_path, framework=\"pt\") as f:\n", - " if 'base_model.model.classifier.out_proj.bias' in f.keys():\n", - " tensor = f.get_tensor('base_model.model.classifier.out_proj.bias')\n", - " elif 'base_model.model.classifier.bias' in f.keys():\n", - " tensor = f.get_tensor('base_model.model.classifier.bias')\n", - " else:\n", - " raise KeyError(f\"Neither 'base_model.model.classifier.out_proj.bias' nor 'base_model.model.classifier.bias' found in the file({file_path}).\")\n", - "\n", - " elif (adapter_path / 'adapter_model.bin').exists():\n", - " file_path = adapter_path / 'adapter_model.bin'\n", - " state_dict = torch.load(file_path)\n", - " if 'base_model.model.classifier.out_proj.bias' in state_dict.keys():\n", - " tensor = state_dict['base_model.model.classifier.out_proj.bias']\n", - " elif 'base_model.model.classifier.bias' in f.keys():\n", - " tensor = state_dict['base_model.model.classifier.bias']\n", - " else:\n", - " raise KeyError(f\"Neither 'base_model.model.classifier.out_proj.bias' nor 'base_model.model.classifier.bias' found in the file({file_path}).\")\n", - "\n", - " else:\n", - " raise FileNotFoundError(f\"Neither 'adapter_model.safetensors' nor 'adapter_model.bin' found in the provided path({adapter_path}).\")\n", - "\n", - " num_labels = list(tensor.shape)[0]\n", - " return num_labels\n", - "\n", - "def get_num_labels_and_task_type_by_adapter(adapter_path):\n", - " adapter_path = Path(adapter_path)\n", - "\n", - " task_type = None\n", - " if (adapter_path / 'adapter_model.safetensors').exists():\n", - " file_path = adapter_path / 'adapter_model.safetensors'\n", - " with safe_open(file_path, framework=\"pt\") as f:\n", - " if 'base_model.model.classifier.out_proj.bias' in f.keys():\n", - " tensor = f.get_tensor('base_model.model.classifier.out_proj.bias')\n", - " elif 'base_model.model.classifier.bias' in f.keys():\n", - " task_type = 'token_classification'\n", - " tensor = f.get_tensor('base_model.model.classifier.bias')\n", - " else:\n", - " raise KeyError(f\"Neither 'base_model.model.classifier.out_proj.bias' nor 'base_model.model.classifier.bias' found in the file({file_path}).\")\n", - "\n", - " elif (adapter_path / 'adapter_model.bin').exists():\n", - " file_path = adapter_path / 'adapter_model.bin'\n", - " state_dict = torch.load(file_path)\n", - " if 'base_model.model.classifier.out_proj.bias' in state_dict.keys():\n", - " tensor = state_dict['base_model.model.classifier.out_proj.bias']\n", - " elif 'base_model.model.classifier.bias' in f.keys():\n", - " task_type = 'token_classification'\n", - " tensor = state_dict['base_model.model.classifier.bias']\n", - " else:\n", - " raise KeyError(f\"Neither 'base_model.model.classifier.out_proj.bias' nor 'base_model.model.classifier.bias' found in the file({file_path}).\")\n", - "\n", - " else:\n", - " raise FileNotFoundError(f\"Neither 'adapter_model.safetensors' nor 'adapter_model.bin' found in the provided path({adapter_path}).\")\n", - "\n", - " num_labels = list(tensor.shape)[0]\n", - " 
if task_type != 'token_classification':\n", - " if num_labels > 1:\n", - " task_type = 'classification'\n", - " elif num_labels == 1:\n", - " task_type = 'regression'\n", - "\n", - " return num_labels, task_type\n", - "\n", - "################################################################################\n", - "############################ INFO ##############################################\n", - "################################################################################\n", - "clear_output(wait=True)\n", - "print(\"Installation finished!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3Uxag_RSBI7e" - }, - "source": [ - "# **2: Train and Share your Protein Model** \n", - "\n", - "You can **train** a model based on pre-trained SaProt, or **continually train** a fine-tuned model in SaprotHub.\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8po43iIU7NcG" - }, - "source": [ - "## **2.1: Train your Model** \n", - "\n", - "> 📍Please see the [tutorial](https://github.com/westlake-repl/SaprotHub/wiki/2.1:-Train-your-model)\n", - "\n", - "\n", - "⚠️If you want to **interrupt** the training, **do not** click the run button again. Please refer to [here](https://github.com/westlake-repl/SaprotHub/wiki/2.1:-Train-your-model#interrupt-training-to-avoid-overfitting).\n", - "\n", - "Example datasets are available in at /content/SaprotHub/upload_files/example_csv_dataset and [Github Repository](https://github.com/westlake-repl/SaprotHub/tree/main/upload_files/example_csv_dataset).\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "vqdmLslQBI7e" - }, - "outputs": [], - "source": [ - "##@title **2.1: Train your Model** \n", - "\n", - "################################################################################\n", - "############################# ADVANCED CONFIG ##################################\n", - "################################################################################\n", - "\n", - "# training config\n", - "GPU_batch_size = 0\n", - "accumulate_grad_batches = 0\n", - "num_workers = 2\n", - "seed = 20000812\n", - "\n", - "# lora config\n", - "r = 8\n", - "lora_dropout = 0.0\n", - "lora_alpha = 16\n", - "\n", - "# dataset config\n", - "val_check_interval=0.5\n", - "limit_train_batches=1.0\n", - "limit_val_batches=1.0\n", - "limit_test_batches=1.0\n", - "\n", - "\n", - "mask_struc_ratio=None\n", - "\n", - "################################################################################\n", - "################################## MARKDOWN #################################\n", - "################################################################################\n", - "##@markdown ⚠️If you want to **interrupt** the training, **do not** click the run button again. 
Please refer to [here](https://github.com/westlake-repl/SaprotHub/wiki/2.1:-Train-your-model#interrupt-training-to-avoid-overfitting).\n", - "\n", - "##@markdown Example datasets are available in at /content/SaprotHub/upload_files/example_csv_dataset and [Github Repository](https://github.com/westlake-repl/SaprotHub/tree/main/upload_files/example_csv_dataset).\n", - "\n", - "##@markdown > 📍Please see the [tutorial](https://github.com/westlake-repl/SaprotHub/wiki/2.1:-Train-your-model)\n", - "\n", - "if torch.cuda.is_available() is False:\n", - " raise BaseException(\"Please refer to Section 1.1 to switch your Runtime to a GPU!\")\n", - "\n", - "################################################################################\n", - "################################## TASK CONFIG #################################\n", - "################################################################################\n", - "#@markdown # 1. Task\n", - "task_name = \"demo\" # @param {type:\"string\"}\n", - "task_type = \"Protein-level Classification\" # @param [\"Protein-level Classification\", \"Protein-level Regression\", \"Residue-level Classification\", \"Protein-protein Classification\", \"Protein-protein Regression\"]\n", - "original_task_type = task_type\n", - "task_type = task_type_dict[task_type]\n", - "\n", - "if task_type in [\"classification\", 'token_classification', 'pair_classification']:\n", - "\n", - " print(Fore.BLUE+'Enter the number of category in your training dataset here:'+Style.RESET_ALL)\n", - " num_of_categories = ipywidgets.BoundedIntText(\n", - " # value=7,\n", - " min=2,\n", - " max=1000000,\n", - " step=1,\n", - " # description='num_of_category: \\n',\n", - " disabled=False)\n", - " num_of_categories.layout.width = \"100px\"\n", - " display(num_of_categories)\n", - "\n", - "#@markdown
\n", - "\n", - "################################################################################\n", - "#################################### MODEL CONFIG #####################################\n", - "################################################################################\n", - "#@markdown # 2. Model\n", - "\n", - "base_model = \"Official pretrained SaProt (35M)\" # @param [\"Official pretrained SaProt (35M)\", \"Official pretrained SaProt (650M)\", \"Trained by yourself on ColabSaprot\", \"Shared by peers on SaprotHub\", \"Saved in your local computer\"]\n", - "\n", - "# continue learning\n", - "if base_model in [\"Trained by yourself on ColabSaprot\", \"Shared by peers on SaprotHub\", \"Saved in your local computer\"]:\n", - " continue_learning = True\n", - " adapter_combobox = select_adapter_from(task_type, use_model_from=base_model)\n", - "else:\n", - " continue_learning = False\n", - "\n", - "#@markdown
\n", - "\n", - "################################################################################\n", - "################################### DATASET CONFIG ####################################\n", - "################################################################################\n", - "#@markdown # 3. Dataset\n", - "\n", - "data_type = \"SaprotHub Dataset\" # @param [\"SaprotHub Dataset\", \"Multiple AA Sequences\", \"Multiple SA Sequences\", \"Multiple UniProt IDs\", \"Multiple PDB/CIF Structures\", \"Multiple pairs of AA Sequences\", \"Multiple pairs of SA Sequences\", \"Multiple pairs of UniProt IDs\", \"Multiple pairs of PDB/CIF Structures\"]\n", - "check_task_type_and_data_type(original_task_type, data_type)\n", - "\n", - "raw_data = input_raw_data_by_data_type(data_type)\n", - "\n", - "#@markdown
\n", - "\n", - "################################################################################\n", - "################################### TRAIN CONFIG ####################################\n", - "################################################################################\n", - "#@markdown # 4. Training\n", - "\n", - "batch_size = \"Adaptive\" # @param [\"Adaptive\", \"1\", \"2\", \"4\", \"8\", \"16\", \"32\", \"64\", \"128\", \"256\"]\n", - "max_epochs = 1 # @param [\"10\", \"20\", \"50\"] {type:\"raw\", allow-input: true}\n", - "learning_rate = 1.0e-3 # @param [\"1.0e-3\", \"5.0e-4\", \"1.0e-4\"] {type:\"raw\", allow-input: true}\n", - "\n", - "#@markdown
\n", - "\n", - "\n", - "################################################################################\n", - "################################# CONFIG #######################################\n", - "################################################################################\n", - "\n", - "from saprot.config.config_dict import Default_config\n", - "config = copy.deepcopy(Default_config)\n", - "\n", - "################################################################################\n", - "################################### TRAIN ####################################\n", - "################################################################################\n", - "\n", - "def train(button):\n", - " global base_model\n", - " global GPU_batch_size\n", - " global accumulate_grad_batches\n", - "\n", - " button.disabled = True\n", - " button.description = 'Training...'\n", - " button.button_style = ''\n", - "\n", - "################################################################################\n", - "################################### DATASET CONFIRM ####################################\n", - "################################################################################\n", - " csv_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)\n", - " check_column_label_and_stage(csv_dataset_path)\n", - " from saprot.utils.construct_lmdb import construct_lmdb\n", - " construct_lmdb(csv_dataset_path, LMDB_HOME, task_name, task_type)\n", - " lmdb_dataset_path = LMDB_HOME / task_name\n", - "\n", - "################################################################################\n", - "################################### MODEL CONFIRM ####################################\n", - "################################################################################\n", - "\n", - " # base_model\n", - " if continue_learning:\n", - " adapter_path = ADAPTER_HOME / task_type / adapter_combobox.value\n", - " print(f\"Training on an existing model: {adapter_path}\")\n", - "\n", - " if base_model == \"Shared by peers on SaprotHub\":\n", - " if not adapter_path.exists():\n", - " snapshot_download(repo_id=adapter_combobox.value, repo_type=\"model\", local_dir=adapter_path)\n", - "\n", - " adapter_config_path = Path(adapter_path) / \"adapter_config.json\"\n", - " assert adapter_config_path.exists(), f\"Can't find {adapter_config_path}\"\n", - " with open(adapter_config_path, 'r') as f:\n", - " adapter_config = json.load(f)\n", - " base_model = adapter_config['base_model_name_or_path']\n", - "\n", - " elif base_model == \"Official pretrained SaProt (35M)\":\n", - " base_model = \"westlake-repl/SaProt_35M_AF2\"\n", - "\n", - " elif base_model == \"Official pretrained SaProt (650M)\":\n", - " base_model = \"westlake-repl/SaProt_650M_AF2\"\n", - "\n", - " # model size and model name\n", - " if base_model == \"westlake-repl/SaProt_650M_AF2\":\n", - " model_size = \"650M\"\n", - " model_name = f\"Model-{task_name}-{model_size}\"\n", - " elif base_model == \"westlake-repl/SaProt_35M_AF2\":\n", - " model_size = \"35M\"\n", - " model_name = f\"Model-{task_name}-{model_size}\"\n", - "\n", - " config.setting.run_mode = \"train\"\n", - " config.setting.seed = seed\n", - "\n", - "################################################################################\n", - "################################# MODEL ########################################\n", - "################################################################################\n", - "\n", - " if task_type in [\"classification\", \"token_classification\", 
\"pair_classification\"]:\n", - " config.model.kwargs.num_labels = num_of_categories.value\n", - "\n", - " config.model.model_py_path = model_type_dict[task_type]\n", - " config.model.kwargs.config_path = base_model\n", - " config.dataset.kwargs.tokenizer = base_model\n", - "\n", - " config.model.save_path = str(ADAPTER_HOME / f\"{task_type}\" / \"Local\" / model_name)\n", - "\n", - " if task_type in [\"regression\", \"pair_regression\"]:\n", - " config.model.kwargs.extra_config = {}\n", - " config.model.kwargs.extra_config.attention_probs_dropout_prob=0\n", - " config.model.kwargs.extra_config.hidden_dropout_prob=0\n", - "\n", - " config.model.kwargs.lora_kwargs = EasyDict({\n", - " \"is_trainable\": True,\n", - " \"num_lora\": 1,\n", - " \"r\": r,\n", - " \"lora_dropout\": lora_dropout,\n", - " \"lora_alpha\": lora_alpha,\n", - " \"config_list\": []})\n", - " if continue_learning:\n", - " config.model.kwargs.lora_kwargs.config_list.append({\"lora_config_path\": adapter_path})\n", - "\n", - "################################################################################\n", - "################################# DATASET ######################################\n", - "################################################################################\n", - "\n", - " config.dataset.dataset_py_path = dataset_type_dict[task_type]\n", - "\n", - " config.dataset.train_lmdb = str(lmdb_dataset_path / \"train\")\n", - " config.dataset.valid_lmdb = str(lmdb_dataset_path / \"valid\")\n", - " config.dataset.test_lmdb = str(lmdb_dataset_path / \"test\")\n", - "\n", - " # num_workers\n", - " config.dataset.dataloader_kwargs.num_workers = num_workers\n", - "\n", - " # mask_struc\n", - " # config.dataset.kwargs.mask_struc_ratio= mask_struc_ratio\n", - "\n", - " ################################################################################\n", - " ######################## batch size ############################################\n", - " ################################################################################\n", - " def get_accumulate_grad_samples(num_samples):\n", - " if num_samples > 3200:\n", - " return 64\n", - " elif 1600 < num_samples <= 3200:\n", - " return 32\n", - " elif 800 < num_samples <= 1600:\n", - " return 16\n", - " elif 400 < num_samples <= 800:\n", - " return 8\n", - " elif 200 < num_samples <= 400:\n", - " return 4\n", - " elif 100 < num_samples <= 200:\n", - " return 2\n", - " else:\n", - " return 1\n", - "\n", - " # advanced config\n", - " if (GPU_batch_size > 0) and (accumulate_grad_batches > 0):\n", - " config.dataset.dataloader_kwargs.batch_size = GPU_batch_size\n", - " config.Trainer.accumulate_grad_batches= accumulate_grad_batches\n", - "\n", - " elif (GPU_batch_size == 0) and (accumulate_grad_batches == 0):\n", - "\n", - " # batch_size\n", - " if base_model == \"westlake-repl/SaProt_650M_AF2\" and root_dir == \"/content\":\n", - " GPU_batch_size = 1\n", - " else:\n", - " GPU_batch_size_dict = {\n", - " \"Tesla T4\": 2,\n", - " \"NVIDIA L4\": 2,\n", - " \"NVIDIA A100-SXM4-40GB\": 4,\n", - " }\n", - " GPU_name = torch.cuda.get_device_name(0)\n", - " GPU_batch_size = GPU_batch_size_dict[GPU_name] if GPU_name in GPU_batch_size_dict else 2\n", - "\n", - " if task_type in [\"pair_classification\", \"pair_regression\"]:\n", - " GPU_batch_size = int(max(GPU_batch_size / 2, 1))\n", - "\n", - " config.dataset.dataloader_kwargs.batch_size = GPU_batch_size\n", - "\n", - " # accumulate_grad_batches\n", - " if batch_size == \"Adaptive\":\n", - "\n", - " env = 
lmdb.open(config.dataset.train_lmdb, readonly=True)\n", - "\n", - " with env.begin() as txn:\n", - " stat = txn.stat()\n", - " num_samples = stat['entries']\n", - "\n", - " accumulate_grad_samples = get_accumulate_grad_samples(num_samples)\n", - "\n", - " else:\n", - " accumulate_grad_samples = int(batch_size)\n", - "\n", - " accumulate_grad_batches = max(int(accumulate_grad_samples / GPU_batch_size), 1)\n", - "\n", - " config.Trainer.accumulate_grad_batches= accumulate_grad_batches\n", - "\n", - " else:\n", - " raise BaseException(f\"Please make sure `GPU_batch_size`({GPU_batch_size}) and `accumulate_grad_batches`({accumulate_grad_batches}) are both greater than zero!\")\n", - "\n", - " ################################################################################\n", - " ############################## TRAINER #########################################\n", - " ################################################################################\n", - "\n", - " config.Trainer.accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"\n", - "\n", - " # epoch\n", - " config.Trainer.max_epochs = max_epochs\n", - " # test only: load the existing model\n", - " if config.Trainer.max_epochs == 0 and continue_learning:\n", - " config.model.save_path = config.model.kwargs.lora_kwargs.config_list[0]['lora_config_path']\n", - "\n", - " # learning rate\n", - " config.model.lr_scheduler_kwargs.init_lr = learning_rate\n", - "\n", - " # trainer\n", - " config.Trainer.limit_train_batches=limit_train_batches\n", - " config.Trainer.limit_val_batches=limit_val_batches\n", - " config.Trainer.limit_test_batches=limit_test_batches\n", - " config.Trainer.val_check_interval=val_check_interval\n", - "\n", - " # strategy\n", - " strategy = {\n", - " # - deepspeed\n", - " # 'class': 'DeepSpeedStrategy',\n", - " # 'stage': 2\n", - "\n", - " # - None\n", - " # 'class': None,\n", - "\n", - " # - DP\n", - " # 'class': 'DataParallelStrategy',\n", - "\n", - " # - DDP\n", - " # 'class': 'DDPStrategy',\n", - " # 'find_unused_parameter': True\n", - " }\n", - " config.Trainer.strategy = strategy\n", - "\n", - " ################################################################################\n", - " ############################## Run the task ####################################\n", - " ################################################################################\n", - "\n", - " print('='*100)\n", - " print(Fore.BLUE+f\"Training task type: {task_type}\"+Style.RESET_ALL)\n", - " print(Fore.BLUE+f\"Dataset: {lmdb_dataset_path}\"+Style.RESET_ALL)\n", - " print(Fore.BLUE+f\"Base Model: {config.model.kwargs.config_path}\"+Style.RESET_ALL)\n", - " if continue_learning:\n", - " print(Fore.BLUE+f\"Existing model: {config.model.kwargs.lora_kwargs.config_list[0]['lora_config_path']}\"+Style.RESET_ALL)\n", - " print('='*100)\n", - " pprint.pprint(config)\n", - " print('='*100)\n", - "\n", - " from saprot.scripts.training import finetune\n", - " finetune(config)\n", - "\n", - "\n", - " ################################################################################\n", - " ############################## Save the adapter ################################\n", - " ################################################################################\n", - "\n", - " def add_training_data_type_to_config(metadata_path, training_data_type):\n", - " if metadata_path.exists() is False:\n", - " config_data = {\n", - " 'training_data_type': training_data_type\n", - " }\n", - " with open(metadata_path, 'w') as file:\n", - " json.dump(config_data, file, 
indent=4)\n", - "\n", - " else:\n", - " with open(metadata_path, 'r') as file:\n", - " config_data = json.load(file)\n", - "\n", - " config_data['training_data_type'] = training_data_type\n", - "\n", - " with open(metadata_path, 'w') as file:\n", - " json.dump(config_data, file, indent=4)\n", - "\n", - " metadata_path = Path(config.model.save_path) / \"metadata.json\"\n", - " training_data_type = training_data_type_dict[data_type]\n", - " add_training_data_type_to_config(metadata_path, training_data_type)\n", - "\n", - " print(Fore.BLUE)\n", - " print(f\"Model is saved to \\\"{config.model.save_path}\\\" on Colab Server\")\n", - " print(Style.RESET_ALL)\n", - "\n", - "\n", - " adapter_zip = Path(config.model.save_path) / f\"{model_name}.zip\"\n", - " !cd $config.model.save_path && zip -r $adapter_zip \"adapter_config.json\" \"adapter_model.safetensors\" \"README.md\" \"metadata.json\"\n", - " # !cd $config.model.save_path && zip -r $adapter_zip \"adapter_config.json\" \"adapter_model.safetensors\" \"adapter_model.bin\" \"README.md\" \"metadata.json\"\n", - " print(\"Click to download the model to your local computer\")\n", - " if adapter_zip.exists():\n", - " # files.download(adapter_zip)\n", - " file_download(adapter_zip)\n", - "\n", - "\n", - "\n", - " ################################################################################\n", - " ############################### Modify README ##################################\n", - " ################################################################################\n", - " name = model_name\n", - " description = ''\n", - " label_meanings = ''\n", - "\n", - " with open(f'{config.model.save_path}/adapter_config.json', 'r') as f:\n", - " lora_config = json.load(f)\n", - "\n", - " markdown = f'''\n", - "---\n", - "\n", - "base_model: {base_model} \\n\n", - "library_name: peft\n", - "\n", - "---\n", - "\\n\n", - "\n", - "# Model Card for {name}\n", - "{description}\n", - "\n", - "## Task type\n", - "{original_task_type}\n", - "\n", - "## Model input type\n", - "{training_data_type_dict[data_type]} Sequence\n", - "\n", - "## Label meanings\n", - "{label_meanings}\n", - "\n", - "## LoRA config\n", - "\n", - "- **r:** {lora_config['r']}\n", - "- **lora_dropout:** {lora_config['lora_dropout']}\n", - "- **lora_alpha:** {lora_config['lora_alpha']}\n", - "- **target_modules:** {lora_config['target_modules']}\n", - "- **modules_to_save:** {lora_config['modules_to_save']}\n", - "\n", - "## Training config\n", - "\n", - "- **optimizer:**\n", - " - **class:** AdamW\n", - " - **betas:** (0.9, 0.98)\n", - " - **weight_decay:** 0.01\n", - "- **learning rate:** {config.model.lr_scheduler_kwargs.init_lr}\n", - "- **epoch:** {config.Trainer.max_epochs}\n", - "- **batch size:** {config.dataset.dataloader_kwargs.batch_size * config.Trainer.accumulate_grad_batches}\n", - "- **precision:** 16-mixed \\n\n", - "'''\n", - "\n", - " # Write the markdown output to a file\n", - " with open(f\"{config.model.save_path}/README.md\", \"w\") as file:\n", - " file.write(markdown)\n", - "\n", - "\n", - "button_train = ipywidgets.Button(\n", - " description='Start Training',\n", - " disabled=False,\n", - " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", - " tooltip='Apply',\n", - " # icon='check' # (FontAwesome names without the `fa-` prefix)\n", - " )\n", - "button_train.on_click(train)\n", - "button_train.layout.width = '300px'\n", - "display(button_train)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pUP-Iz2Z5pu0" - }, - "source": [ - 
"## **2.2: Upload your model (Optional)** \n", - "\n", - "\n", - "> 📍Please see the [tutorial](https://github.com/westlake-repl/SaprotHub/wiki/2.2:-Upload-your-model)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "UnKX1BTZBI7f" - }, - "outputs": [], - "source": [ - "#@title **2.2.1: Login Huggingface**\n", - "#@markdown Click the run button to login Huggingface\n", - "################################################################################\n", - "###################### Login HuggingFace #######################################\n", - "################################################################################\n", - "\n", - "from huggingface_hub import notebook_login\n", - "notebook_login()\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "6XlluTsPBI7m" - }, - "outputs": [], - "source": [ - "##@title **2.3: Upload your Model (Optional)**\n", - "#@title **2.2.2: Upload your Model**\n", - "\n", - "\n", - "# #@markdown Your Huggingface adapter repository names follow the format `/`.\n", - "\n", - "################################################################################\n", - "########################## Metadata ###########################################\n", - "################################################################################\n", - "name = \"demo_cls\" # @param {type:\"string\"}\n", - "description = \"This model is used for a demo classification task\" # @param {type:\"string\"}\n", - "\n", - "\n", - "\n", - "# #@markdown > 0: Nucleus
\n", - "# #@markdown > 1: Cytoplasm
\n", - "# #@markdown > 2: Extracellular
\n", - "# #@markdown > ...
\n", - "# #@markdown > 9: Peroxisome
\n", - "\n", - "label_meanings = \"A, B\" #@param {type:\"string\"}\n", - "\n", - "################################################################################\n", - "########################### Move Files ########################################\n", - "################################################################################\n", - "\n", - "from huggingface_hub import HfApi, Repository, ModelFilter\n", - "\n", - "api = HfApi()\n", - "\n", - "user = api.whoami()\n", - "\n", - "if name == \"\":\n", - " name = model_name\n", - "repo_name = user['name'] + '/' + name\n", - "local_dir = Path(\"/content/SaprotHub/model_to_push\") / repo_name\n", - "local_dir.mkdir(parents=True, exist_ok=True)\n", - "\n", - "repo_list = [repo.id for repo in api.list_models(filter=ModelFilter(author=user['name']))]\n", - "if repo_name not in repo_list:\n", - " api.create_repo(repo_name, private=False)\n", - "\n", - "repo = Repository(local_dir=local_dir, clone_from=repo_name)\n", - "\n", - "command = f\"cp {config.model.save_path}/* {local_dir}/\"\n", - "subprocess.run(command, shell=True)\n", - "\n", - "################################################################################\n", - "########################## Modify README ######################################\n", - "################################################################################\n", - "import json\n", - "\n", - "md_path = local_dir / \"README.md\"\n", - "\n", - "\n", - "if task_type in [\"classification\", \"token_classification\", \"pair_classification\"]:\n", - " label_meanings_md = ''\n", - " for index, label in enumerate(label_meanings.split(', ')):\n", - " label_meanings_md += f'''\n", - "{index}: {label.strip()}\n", - "'''\n", - " label_meanings = label_meanings_md\n", - "\n", - "replace_data = {\n", - " '': description,\n", - " '': label_meanings\n", - "}\n", - "\n", - "with open(md_path, \"r\") as file:\n", - " content = file.read()\n", - "\n", - "for key, value in replace_data.items():\n", - " if value != \"\":\n", - " content = content.replace(key, value)\n", - "\n", - "with open(md_path, \"w\") as file:\n", - " file.write(content)\n", - "\n", - "################################################################################\n", - "########################## Upload Model #######################################\n", - "################################################################################\n", - "\n", - "\n", - "repo.push_to_hub(commit_message=\"Upload adapter model\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bQ1JgmrsBI7m" - }, - "source": [ - "# **3: Use SaProt to Predict**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4359yf2P5DAM" - }, - "source": [ - "## **3.1: Classification&Regression Prediction** \n", - "\n", - "> 📍Please see the [tutorial](https://github.com/westlake-repl/SaprotHub/wiki/3.1:-Classification-Regression-Prediction)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "h8qHRJtIQxU4" - }, - "outputs": [], - "source": [ - "from transformers import EsmTokenizer\n", - "import torch\n", - "import copy\n", - "import sys\n", - "from saprot.scripts.training import my_load_model\n", - "\n", - "################################################################################\n", - "################################# TASK #########################################\n", - "################################################################################\n", - "#@markdown # 1. 
Task\n", - "\n", - "task_type = \"Protein-level Classification\" # @param [\"Protein-level Classification\", \"Protein-level Regression\", \"Residue-level Classification\", \"Protein-protein Classification\", \"Protein-protein Regression\"]\n", - "original_task_type = task_type\n", - "task_type = task_type_dict[task_type]\n", - "\n", - "if task_type in [\"classification\", 'token_classification', 'pair_classification']:\n", - "\n", - " print(Fore.BLUE+'The number of categories in your classification task:'+Style.RESET_ALL)\n", - " num_of_categories = ipywidgets.BoundedIntText(\n", - " # value=7,\n", - " min=2,\n", - " # max=10,\n", - " step=1,\n", - " # description='num_of_category: \\n',\n", - " disabled=False)\n", - " num_of_categories.layout.width = \"100px\"\n", - " display(num_of_categories)\n", - "\n", - "#@markdown
\n", - "\n", - "\n", - "################################################################################\n", - "################################## MODEL #######################################\n", - "################################################################################\n", - "#@markdown # 2. Model\n", - "\n", - "use_model_from = \"Trained by yourself on ColabSaprot\" # @param [\"Trained by yourself on ColabSaprot\", \"Shared by peers on SaprotHub\", \"Saved in your local computer\", \"Multi-models on SaprotHub\"]\n", - "if use_model_from == \"Multi-models on SaprotHub\":\n", - " multi_lora = True\n", - "else:\n", - " multi_lora = False\n", - "\n", - "adapter_input = select_adapter_from(task_type, use_model_from)\n", - "#@markdown
\n", - "\n", - "################################################################################\n", - "################################ DATASET #######################################\n", - "################################################################################\n", - "#@markdown # 3. Dataset\n", - "data_type = \"Single AA Sequence\" # @param [\"Single AA Sequence\", \"Single SA Sequence\", \"Single UniProt ID\", \"Single PDB/CIF Structure\", \"Multiple AA Sequences\", \"Multiple SA Sequences\", \"Multiple UniProt IDs\", \"Multiple PDB/CIF Structures\", \"A pair of AA Sequences\", \"A pair of SA Sequences\", \"A pair of UniProt IDs\", \"A pair of PDB/CIF Structures\", \"Multiple pairs of AA Sequences\", \"Multiple pairs of SA Sequences\", \"Multiple pairs of UniProt IDs\", \"Multiple pairs of PDB/CIF Structures\"]\n", - "check_task_type_and_data_type(original_task_type, data_type)\n", - "\n", - "mode = \"Multiple Sequences\" if (data_type in data_type_list_multiple) else \"Single Sequence\"\n", - "\n", - "raw_data = input_raw_data_by_data_type(data_type)\n", - "\n", - "################################################################################\n", - "##################################### PREDICT ###################################\n", - "################################################################################\n", - "def predict(button):\n", - " button.disabled = True\n", - " button.description = 'Predicting...'\n", - " button.button_style = ''\n", - "\n", - " print('\\n')\n", - " print('='*100)\n", - "\n", - " ##############################################################################\n", - " ################################# MODEL ###################################\n", - " ##############################################################################\n", - " if multi_lora:\n", - " if use_model_from == \"Multi-models on ColabSaprot\":\n", - " config_list = [EasyDict({'lora_config_path': ADAPTER_HOME / task_type / lora_config_path}) for lora_config_path in list(adapter_input.value)]\n", - " elif use_model_from == \"Multi-models on SaprotHub\":\n", - " #1. get adapter_list\n", - " repo_id_list = adapter_input.value.replace(\" \", \"\").split(',')\n", - " #2. 
download adapters\n", - " for repo_id in repo_id_list:\n", - " snapshot_download(repo_id=repo_id, repo_type=\"model\", local_dir=ADAPTER_HOME / task_type / repo_id)\n", - " config_list = [EasyDict({'lora_config_path': ADAPTER_HOME / task_type / repo_id}) for repo_id in repo_id_list]\n", - "\n", - " assert len(config_list) > 0, \"Please select your models from the dropdown menu on the output of 3.1!\"\n", - " base_model = get_base_model(ADAPTER_HOME / task_type / config_list[0].lora_config_path)\n", - "\n", - " required_training_data_type_list = []\n", - " for lora_config in config_list:\n", - " required_training_data_type_list.append(check_training_data_type(lora_config.lora_config_path, data_type))\n", - " assert len(set(required_training_data_type_list)) == 1, f\"Error: The input data types of these models are not identical: {required_training_data_type_list}\"\n", - " required_training_data_type = required_training_data_type_list[0]\n", - "\n", - " lora_kwargs = EasyDict({\n", - " \"is_trainable\": False,\n", - " \"num_lora\": len(config_list),\n", - " \"config_list\": config_list\n", - " })\n", - "\n", - " else:\n", - " if use_model_from == \"Shared by peers on SaprotHub\":\n", - " snapshot_download(repo_id=adapter_input.value, repo_type=\"model\", local_dir=ADAPTER_HOME / task_type / adapter_input.value)\n", - "\n", - " adapter_path = ADAPTER_HOME / task_type / adapter_input.value\n", - " base_model = get_base_model(adapter_path)\n", - " required_training_data_type = check_training_data_type(adapter_path, data_type)\n", - " lora_kwargs = {\n", - " \"is_trainable\": False,\n", - " \"num_lora\": 1,\n", - " \"config_list\": [{\"lora_config_path\": adapter_path}]\n", - " }\n", - "\n", - " ##############################################################################\n", - " ################################# DATASET ###################################\n", - " ##############################################################################\n", - " if data_type in data_type_list_multiple:\n", - " csv_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)\n", - " df = read_csv_dataset(csv_dataset_path)\n", - " else:\n", - " single_sa_seq = get_SA_sequence_by_data_type(data_type, raw_data)\n", - " if task_type in [\"pair_classification\", \"pair_regression\"]:\n", - " df = pd.DataFrame({\n", - " 'sequence_1': [single_sa_seq[0]],\n", - " 'sequence_2': [single_sa_seq[1]]\n", - " })\n", - " else:\n", - " df = pd.DataFrame({\n", - " 'sequence': [single_sa_seq]\n", - " })\n", - "\n", - " if (required_training_data_type == \"AA\") and (\"AA\" not in data_type):\n", - " if 'sequence' in df.columns:\n", - " df['sequence'] = df['sequence'].apply(mask_struc_token)\n", - " elif 'sequence_1' in df.columns and 'sequence_2' in df.columns:\n", - " df['sequence_1'] = df['sequence_1'].apply(mask_struc_token)\n", - " df['sequence_2'] = df['sequence_2'].apply(mask_struc_token)\n", - "\n", - " ################################################################################\n", - " ##################################### CONFIG ###################################\n", - " ################################################################################\n", - " from saprot.config.config_dict import Default_config\n", - " config = copy.deepcopy(Default_config)\n", - "\n", - " # task\n", - " if task_type in [ \"classification\", \"token_classification\", \"pair_classification\"]:\n", - " config.model.kwargs.num_labels = num_of_categories.value\n", - " # base model\n", - " config.model.model_py_path = 
model_type_dict[task_type]\n", - " config.model.kwargs.config_path = base_model\n", - " # lora\n", - " config.model.kwargs.lora_kwargs = lora_kwargs\n", - "\n", - " ################################################################################\n", - " ################################### LOAD MODEL ##################################\n", - " ################################################################################\n", - " model = my_load_model(config.model)\n", - " tokenizer = EsmTokenizer.from_pretrained(config.model.kwargs.config_path)\n", - " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - " model.to(device)\n", - "\n", - " ################################################################################\n", - " ################################### INFO #######################################\n", - " ################################################################################\n", - " # clear_output(wait=True)\n", - " print('\\n')\n", - " print('='*100)\n", - "\n", - " print(Fore.BLUE+f\"Task Type: {original_task_type}\"+Style.RESET_ALL)\n", - "\n", - " print(Fore.BLUE+f\"Model ({use_model_from}):\"+Style.RESET_ALL)\n", - " if multi_lora:\n", - " print(Fore.BLUE+f\" Base Model: {base_model}\"+Style.RESET_ALL)\n", - " print(Fore.BLUE+f\" Adapter:\"+Style.RESET_ALL)\n", - " for lora_config in lora_kwargs.config_list:\n", - " print(Fore.BLUE+f\" {lora_config.lora_config_path}\"+Style.RESET_ALL)\n", - " else:\n", - " print(Fore.BLUE+f\" Base Model: {base_model}\"+Style.RESET_ALL)\n", - " print(Fore.BLUE+f\" Adapter: {adapter_path}\"+Style.RESET_ALL)\n", - "\n", - " print(Fore.BLUE+f'Dataset ({data_type}):' +Style.RESET_ALL)\n", - " if mode == \"Multiple Sequences\":\n", - " print(Fore.BLUE+f\" CSV Dataset Path: {csv_dataset_path}\"+Style.RESET_ALL)\n", - " else:\n", - " if \"A pair of\" in data_type:\n", - " print(Fore.BLUE+f\" Sequence 1: {single_sa_seq[0]}\"+Style.RESET_ALL)\n", - " print(Fore.BLUE+f\" Sequence 2: {single_sa_seq[1]}\"+Style.RESET_ALL)\n", - " else:\n", - " print(Fore.BLUE+f\" Sequence: {single_sa_seq}\"+Style.RESET_ALL)\n", - "\n", - " ################################################################################\n", - " ################################### INFERENCE ##################################\n", - " ################################################################################\n", - " print()\n", - " print('='*100)\n", - " print(Fore.BLUE+\"Prediction Result:\"+Style.RESET_ALL)\n", - "\n", - " outputs_list=[]\n", - " if task_type in [\"pair_classification\", \"pair_regression\"]:\n", - " for index, row in tqdm(df.iterrows(), total=df.shape[0]):\n", - " input_1 = tokenizer(row['sequence_1'], return_tensors=\"pt\")\n", - " input_1 = {k: v.to(device) for k, v in input_1.items()}\n", - " input_2 = tokenizer(row['sequence_2'], return_tensors=\"pt\")\n", - " input_2 = {k: v.to(device) for k, v in input_2.items()}\n", - "\n", - " with torch.no_grad(): outputs = model(input_1, input_2)\n", - " outputs_list.append(outputs)\n", - " else:\n", - " for index in tqdm(range(len(df))):\n", - " seq = df['sequence'].iloc[index]\n", - " inputs = tokenizer(seq, return_tensors=\"pt\")\n", - " inputs = {k: v.to(device) for k, v in inputs.items()}\n", - " with torch.no_grad(): outputs = model(inputs)\n", - " outputs_list.append(outputs)\n", - "\n", - " ################################################################################\n", - " ################################### RESULT ##################################\n", - " 
################################################################################\n", - " timestamp = str(datetime.now().strftime(\"%Y%m%d%H%M%S\"))\n", - " output_file = OUTPUT_HOME / f'output_{timestamp}.csv'\n", - "\n", - " if task_type == \"pair_classification\":\n", - " softmax_output_list = [F.softmax(output, dim=1).squeeze().tolist() for output in outputs_list]\n", - " print()\n", - " for index, output in enumerate(softmax_output_list):\n", - " print(f\"For Sequence Pair {index}, Category {output.index(max(output))}, Probability: {output}\")\n", - " df.loc[index, 'result'] = output.index(max(output))\n", - " df.loc[index, 'probability'] = ', '.join(map(str, output))\n", - " df.to_csv(output_file, index=False)\n", - "\n", - " elif task_type == \"pair_regression\":\n", - " print()\n", - " for index, output in enumerate(outputs_list):\n", - " print(f\"For Sequence Pair {index}, Value {output.cpu().item()}\")\n", - " df['score'] = [output.cpu().item() for output in outputs_list]\n", - " df.to_csv(output_file, index=False)\n", - "\n", - " elif task_type == \"classification\":\n", - " print()\n", - " softmax_output_list = [F.softmax(output, dim=1).squeeze().tolist() for output in outputs_list]\n", - " for index, output in enumerate(softmax_output_list):\n", - " print(f\"For Sequence {index}, Category {output.index(max(output))}, Probability: {output}\")\n", - " df.loc[index, 'result'] = output.index(max(output))\n", - " df.loc[index, 'probability'] = ', '.join(map(str, output))\n", - " df.to_csv(output_file, index=False)\n", - "\n", - " elif task_type == \"regression\":\n", - " print()\n", - " for index, output in enumerate(outputs_list):\n", - " print(f\"For Sequence {index}, Value {output.item()}\")\n", - " df['score'] = [output.cpu().item() for output in outputs_list]\n", - " df.to_csv(output_file, index=False)\n", - "\n", - " elif task_type == \"token_classification\":\n", - " seq_prob_df_list = []\n", - " softmax_output_list = [F.softmax(output, dim=-1).squeeze().tolist() for output in outputs_list]\n", - " # print(\"The probability of each category:\")\n", - " for seq_index, seq in enumerate(softmax_output_list):\n", - " seq_prob_df = pd.DataFrame(seq)[1:-1]\n", - " # print('='*100)\n", - " # print(f'Sequence {seq_index + 1}:')\n", - " # print(seq_prob_df.to_string())\n", - " seq_prob_df['seq_index'] = seq_index\n", - " seq_prob_df['aa_index'] = seq_prob_df.index\n", - " seq_prob_df['sequence'] = df.loc[seq_index, 'sequence']\n", - " seq_prob_df_list.append(seq_prob_df)\n", - " combined_df = pd.concat(seq_prob_df_list, ignore_index=False)\n", - " combined_df.to_csv(output_file, index=True)\n", - "\n", - " print()\n", - " print('='*100)\n", - " print(Fore.BLUE+f\"The prediction result is saved to {output_file} and your local computer.\"+Style.RESET_ALL)\n", - " file_download(output_file)\n", - "\n", - "################################################################################\n", - "#################################### BUTTON #################################\n", - "################################################################################\n", - "button_predict = ipywidgets.Button(\n", - " description='Make Prediction',\n", - " disabled=False,\n", - " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", - " tooltip='Apply',\n", - " # icon='check' # (FontAwesome names without the `fa-` prefix)\n", - " )\n", - "button_predict.on_click(predict)\n", - "# button_predict.layout.width = '500px'\n", - "display(button_predict)\n" - ] - }, - { - "cell_type": 
"markdown", - "metadata": { - "id": "mPtlNrDl5Sru" - }, - "source": [ - "## **3.2: Zero-shot Mutational Effect Prediction** \n", - "\n", - "> 📍Please see the [tutorial](https://github.com/westlake-repl/SaprotHub/wiki/3.2:-Mutational-Effect-Prediction)\n", - "\n", - "\n", - "The model takes the \"wild type sequence\" and [\"mutation information\"](https://github.com/westlake-repl/SaprotHub/wiki/3.2:-Mutational-Effect-Prediction#mutation-information) as **input** and **outputs** a \"score\".\n", - "A **positive score** means the mutation is **better** than the wild type from evolution perspective.\n", - "\n", - "Our model is pre-trained based on protein structures and performs best when provided with structural data.\n", - "\n", - "**For this task, if you only have the protein AA sequence, we strongly recommend using AF2 to predict its structure and then using the SA sequence as input.**\n", - "\n", - "Here you can **convert your data into SA Sequence** format.\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "uxD_KOF1BI7n" - }, - "outputs": [], - "source": [ - "\n", - "mutation_task = \"Saturation mutagenesis\" #@param [\"Single-site or Multi-site mutagenesis\", \"Saturation mutagenesis\"]\n", - "\n", - "# data_type = \"Single AA Sequence\" # @param [\"Single AA Sequence\", \"Single SA Sequence\", \"Single UniProt ID\", \"Single PDB/CIF Structure\", \"Multiple AA Sequences\", \"Multiple SA Sequences\", \"Multiple UniProt IDs\", \"Multiple PDB/CIF Structures\"]\n", - "data_type = \"Single SA Sequence\" # @param [\"Single SA Sequence\", \"Single UniProt ID\", \"Single PDB/CIF Structure\", \"Multiple SA Sequences\", \"Multiple UniProt IDs\", \"Multiple PDB/CIF Structures\"]\n", - "raw_data = input_raw_data_by_data_type(data_type)\n", - "\n", - "mode = \"Multiple Sequences\" if data_type in data_type_list_multiple else \"Single Sequence\"\n", - "\n", - "if mutation_task == \"Single-site or Multi-site mutagenesis\":\n", - " if mode == \"Single Sequence\":\n", - " input_mut = ipywidgets.Text(\n", - " value=None,\n", - " placeholder='Enter Single Mutation Information here',\n", - " # description='SA Sequence:',\n", - " disabled=False)\n", - " print(Fore.BLUE+\"Mutation:\"+Style.RESET_ALL)\n", - " input_mut.layout.width = '500px'\n", - " display(input_mut)\n", - "\n", - "def mutational_effect_predict(button):\n", - " button.disabled = True\n", - " button.description = 'Clicked'\n", - " button.button_style = ''\n", - "\n", - "\n", - " #@title 3.2.2: Get your Result\n", - "\n", - " ################################################################################\n", - " ################################# DATASET ###################################\n", - " ################################################################################\n", - " if mode == \"Single Sequence\":\n", - " seq = get_SA_sequence_by_data_type(data_type, raw_data)\n", - " else:\n", - " dataset_csv_path = get_SA_sequence_by_data_type(data_type, raw_data)\n", - "\n", - " ################################################################################\n", - " ################################# Task Info ####################################\n", - " ################################################################################\n", - " base_model = \"westlake-repl/SaProt_650M_AF2\"\n", - "\n", - " # clear_output(wait=True)\n", - "\n", - " print(Fore.BLUE)\n", - " print(f\"Mutation task: {mutation_task}\")\n", - " print(f\"Mode: {mode}\")\n", - " 
print(f\"Model: {base_model}\")\n", - " if mode == \"Multiple Sequences\":\n", - " print(Fore.BLUE+f\"Dataset: {dataset_csv_path}\"+Style.RESET_ALL)\n", - " else:\n", - " print(Fore.BLUE+f\"Dataset: {seq}\"+Style.RESET_ALL)\n", - "\n", - " print(Style.RESET_ALL)\n", - "\n", - " print(f\"Predicting...\")\n", - " timestamp = datetime.now().strftime(\"%y%m%d%H%M%S\")\n", - "\n", - " ################################################################################\n", - " ################################# load model ###################################\n", - " ################################################################################\n", - "\n", - " from saprot.model.saprot.saprot_foldseek_mutation_model import SaprotFoldseekMutationModel\n", - "\n", - " config = {\n", - " \"foldseek_path\": None,\n", - " \"config_path\": base_model,\n", - " \"load_pretrained\": True,\n", - " }\n", - "\n", - " try:\n", - " zero_shot_model\n", - " except Exception:\n", - " zero_shot_model = SaprotFoldseekMutationModel(**config)\n", - " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - " zero_shot_model.to(device)\n", - "\n", - " ################################################################################\n", - " ########################### Single Sequence ####################################\n", - " ################################################################################\n", - " if mode == \"Single Sequence\":\n", - "\n", - " if mutation_task == \"Single-site or Multi-site mutagenesis\":\n", - " mut = input_mut.value\n", - " # validate mut\n", - " aa_seq = seq[0::2]\n", - " for m in mut.split(':'):\n", - " ori_aa = m[0]\n", - " pos = int(m[1:-1])\n", - " mut_aa = m[-1]\n", - " assert aa_seq[pos-1] == ori_aa, f\"The provided mutation information contains an error ({m}): the original amino acid at position {pos} ({ori_aa}) does not match your sequence ({aa_seq[pos-1]}).\"\n", - "\n", - " score = zero_shot_model.predict_mut(seq, mut)\n", - "\n", - " print()\n", - " print(\"=\"*100)\n", - " print(Fore.BLUE+\"Output:\"+Style.RESET_ALL)\n", - " print(f\"The score of mutation {mut} is {Fore.BLUE}{score}{Style.RESET_ALL}\")\n", - "\n", - " if mutation_task==\"Saturation mutagenesis\":\n", - " timestamp = datetime.now().strftime(\"%y%m%d%H%M%S\")\n", - " output_path = OUTPUT_HOME / f'{timestamp}_prediction_output.csv'\n", - "\n", - " mut_dicts = []\n", - " for pos in tqdm(range(1, int(len(seq) / 2)+1), total=int(len(seq) / 2)+1, leave=False, desc=f\"Predicting\"):\n", - " mut_dict = zero_shot_model.predict_pos_mut(seq, pos)\n", - " mut_dicts.append(mut_dict)\n", - "\n", - " mut_list = [{'mutation': key, 'score': value} for d in mut_dicts for key, value in d.items()]\n", - " df = pd.DataFrame(mut_list)\n", - " df.to_csv(output_path, index=None)\n", - "\n", - " print()\n", - " print(\"=\"*100)\n", - " print(Fore.BLUE+\"Output:\"+Style.RESET_ALL)\n", - " # files.download(output_path)\n", - " file_download(output_path)\n", - " print(f\"\\n{Fore.BLUE}The result has been saved to {output_path} and your local computer.{Style.RESET_ALL}\")\n", - "\n", - " ################################################################################\n", - " ########################### Multiple Sequences #################################\n", - " ################################################################################\n", - " if mode == \"Multiple Sequences\":\n", - "\n", - " dataset_df = read_csv_dataset(dataset_csv_path)\n", - " results = []\n", - "\n", - " if mutation_task==\"Single-site or Multi-site 
mutagenesis\":\n", - " for index, row in tqdm(dataset_df.iterrows(), total=len(dataset_df), leave=False, desc=f\"Predicting\"):\n", - " seq = row['sequence']\n", - " mut_info = row['mutation']\n", - " results.append(zero_shot_model.predict_mut(seq, mut_info).cpu().item())\n", - "\n", - " print()\n", - " print(\"=\"*100)\n", - " print(Fore.BLUE+\"Output:\"+Style.RESET_ALL)\n", - "\n", - " # result_df = pd.DataFrame()\n", - " # result_df['sequence'] = dataset_df['sequence']\n", - " # result_df['mutation'] = dataset_df['mutation']\n", - " dataset_df['score'] = results\n", - "\n", - " output_path = OUTPUT_HOME / f\"{timestamp}_prediction_output_{Path(dataset_csv_path).stem}.csv\"\n", - " dataset_df.to_csv(output_path, index=None)\n", - " file_download(output_path)\n", - " print(f\"{Fore.BLUE}The result has been saved to {output_path} and your local computer {Style.RESET_ALL}\")\n", - "\n", - " else:\n", - " for index, row in tqdm(dataset_df.iterrows(), total=len(dataset_df), leave=False, desc=f\"Predicting\"):\n", - " seq = row['sequence']\n", - " mut_dicts = []\n", - " for pos in range(1, int(len(seq) / 2)+1):\n", - " mut_dict = zero_shot_model.predict_pos_mut(seq, pos)\n", - " mut_dicts.append(mut_dict)\n", - " mut_list = [{'mutation': key, 'score': value} for d in mut_dicts for key, value in d.items()]\n", - " result_df = pd.DataFrame(mut_list)\n", - " results.append(result_df)\n", - "\n", - " print()\n", - " print(\"=\"*100)\n", - " print(Fore.BLUE+\"Output:\"+Style.RESET_ALL)\n", - "\n", - " zip_files = []\n", - " for i in range(len(results)):\n", - " output_path = OUTPUT_HOME / f\"{timestamp}_prediction_output_{Path(dataset_csv_path).stem}_Sequence{i+1}.csv\"\n", - " results[i].to_csv(output_path, index=None)\n", - " zip_files.append(output_path)\n", - "\n", - " # zip and download zip to local computer\n", - " zip_path = OUTPUT_HOME / f\"{timestamp}_{Path(dataset_csv_path).stem}.zip\"\n", - " with zipfile.ZipFile(zip_path, 'w') as zipf:\n", - " for file in zip_files:\n", - " zipf.write(file, os.path.basename(file))\n", - " # files.download(zip_path)\n", - " print(f\"{Fore.BLUE}The result has been saved to {zip_path} and your local computer{Style.RESET_ALL}\")\n", - " file_download(zip_path)\n", - "\n", - "button_mutational_effect_predict = ipywidgets.Button(\n", - " description='Mutational Effect Predict',\n", - " disabled=False,\n", - " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", - " tooltip='Apply',\n", - " # icon='check' # (FontAwesome names without the `fa-` prefix)\n", - " )\n", - "button_mutational_effect_predict.on_click(mutational_effect_predict)\n", - "button_mutational_effect_predict.layout.width = '300px'\n", - "display(button_mutational_effect_predict)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uAlQdqTcBI7n" - }, - "source": [ - "## **3.3: Inverse Folding Prediction** \n", - "\n", - "\n", - "\n", - "> 📍Please see the [tutorial](https://github.com/westlake-repl/SaprotHub/wiki/3.3:-Inverse-Folding-Prediction)\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "DT7M_DU2BI7n" - }, - "outputs": [], - "source": [ - "#@title **3.3.1: Upload .pdb/.cif structure file**\n", - "\n", - "#@markdown After clicking the run button, an upload button will appear for you to upload your .pdb/.cif structure file.\n", - "\n", - "#@markdown Note:since you may not know the AA type, you can simply populate your .pdb/.cif file with any random AA. 
If you want to predit partial positions given some accurate AA information in other positions, just input the accurate AA in these positions and any random AA in unknown positions.\n", - "\n", - "#@markdown After uploading is finished, the .pdb/.cif structure will be transformed into the corresponding AA Sequence and Structure (3Di) Sequence.\n", - "\n", - "data_type = \"Single PDB/CIF Structure\"\n", - "# raw_data = input_raw_data_by_data_type(data_type)\n", - "\n", - "def get_structure_file():\n", - " print(\"Please provide the structure type, chain and your structure file.\")\n", - "\n", - " dropdown_type = ipywidgets.Dropdown(\n", - " value=\"PDB\",\n", - " options=[\"PDB\", \"AF2\"],\n", - " disabled=False)\n", - " dropdown_type.layout.width = '500px'\n", - " print(Fore.BLUE+\"Structure type:\"+Style.RESET_ALL)\n", - " display(dropdown_type)\n", - "\n", - " input_chain = ipywidgets.Text(\n", - " value=\"A\",\n", - " placeholder=f'Enter the name of chain here',\n", - " disabled=False)\n", - " input_chain.layout.width = '500px'\n", - " print(Fore.BLUE+\"Chain:\"+Style.RESET_ALL)\n", - " display(input_chain)\n", - "\n", - " print(Fore.BLUE+\"Please upload a .pdb/.cif file\"+Style.RESET_ALL)\n", - " pdb_file_path = upload_file(STRUCTURE_HOME)\n", - " return pdb_file_path, pdb_file_path.stem, dropdown_type, input_chain\n", - "\n", - "\n", - "backbone_path, stem, dropdown_type, input_chain = get_structure_file()\n", - "raw_data = (stem, dropdown_type, input_chain)\n", - "\n", - "sa_seq = get_SA_sequence_by_data_type(data_type, raw_data)\n", - "\n", - "aa_seq = sa_seq[0::2]\n", - "struc_seq = sa_seq[1::2]\n", - "\n", - "# masked_sa_seq = ''\n", - "# for s in sa_seq[1::2]:\n", - "# masked_sa_seq += '#' + s\n", - "\n", - "clear_output(wait=True)\n", - "\n", - "################################################################################\n", - "################################################################################\n", - "################################################################################\n", - "\n", - "print()\n", - "\n", - "input_aa_seq = ipywidgets.Text(\n", - " value=aa_seq,\n", - " placeholder='Enter Amino Acid Sequence here',\n", - " disabled=False)\n", - "print(Fore.BLUE+\"Amino Acid Sequence:\"+Style.RESET_ALL)\n", - "input_aa_seq.layout.width = '500px'\n", - "display(input_aa_seq)\n", - "\n", - "input_struc_seq = ipywidgets.Text(\n", - " value=struc_seq,\n", - " placeholder='Enter Structure Sequence here',\n", - " disabled=False)\n", - "print(Fore.BLUE+\"Structure Sequence:\"+Style.RESET_ALL)\n", - "input_struc_seq.layout.width = '500px'\n", - "display(input_struc_seq)\n", - "\n", - "# print(Fore.RED+\"If you want to mask all amino acids and make prediction, simply clear the 'Amino Acid Sequence' box.\")\n", - "\n", - "backbone_name = os.path.basename(backbone_path)\n", - "show_pdb(backbone_path, color=\"chain\").show()\n", - "print(f\"Backbone visualization of {backbone_name} ({len(struc_seq)} amino acids)\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "baiH-BrBl2Ge" - }, - "outputs": [], - "source": [ - "#@title **3.3.2: Predict Amino Acid Sequence**\n", - "\n", - "#@markdown You can **mask partial or all amino acids** in the AA sequence with '#' at certain positions, allowing the model to make predictions for those masked amino acids.\n", - "\n", - "#@markdown If you want to **mask all amino acids** and make prediction, simply clear the 'masked_aa_seq' box.\n", - "\n", - "#@markdown | Original 
AA Sequence | Masked AA Sequence | Description |\n", - "#@markdown | -------------------- | ------------------ | ------------------------------------------ |\n", - "#@markdown | MEETMKLATMEDTVEYCL | ME#T#KL#TMEDTVEYCL | Predict the 3rd, 5th, and 8th amino acids. |\n", - "#@markdown | MEETMKLATMEDTVEYCL | | Predict all amino acids. |\n", - "\n", - "#@markdown
\n", - "\n", - "# #@markdown Click the run button to get the predicted Amino Acid Sequence\n", - "masked_aa_seq = \"\" # @param {type:\"string\", placeholder:\"mask the amino acids with `#` and then paste the sequence here\"}\n", - "method = \"multinomial\" # @param [\"argmax\", \"multinomial\"]\n", - "num_samples = 10 # @param {type:\"integer\"}\n", - "\n", - "#@markdown - `method` refers to the prediction method. It could be either \"argmax\" or \"multinomial\".\n", - "#@markdown - `argmax` selects the amino acid with the highest probability.\n", - "#@markdown - `multinomial` samples an amino acid from the multinomial distribution.\n", - "\n", - "\n", - "#@markdown - `num_samples` refers to the number of output amino acid sequences.\n", - "\n", - "save_name = \"predicted_seq\" # @param {type:\"string\"}\n", - "\n", - "\n", - "\n", - "################################################################################\n", - "############################### Dataset ########################################\n", - "################################################################################\n", - "\n", - "# masked_aa_seq = input_aa_seq.value\n", - "if masked_aa_seq.strip() == \"\":\n", - " masked_aa_seq = \"#\" * len(input_struc_seq.value)\n", - "\n", - "masked_struc_seq = input_struc_seq.value\n", - "\n", - "# assert len(masked_aa_seq) == len(masked_struc_seq), f\"Please make sure that the amino acid sequence ({len(masked_aa_seq)}) and the structure sequence ({len(masked_struc_seq)}) have the same length.\"\n", - "# masked_sa_seq = ''.join(a + b for a, b in zip(masked_aa_seq, masked_struc_seq))\n", - "\n", - "\n", - "# if num_samples == 1:\n", - "# method = \"argmax\"\n", - "# elif num_samples > 1:\n", - "# method = \"multinomial\"\n", - "# else:\n", - "# raise BaseException(\"\\\"num_samples\\\" should be an integer greater than or equal to 1.\")\n", - "\n", - "################################################################################\n", - "############################### Model ##########################################\n", - "################################################################################\n", - "# base_model = \"westlake-repl/SaProt_650M_AF2\"\n", - "base_model = \"westlake-repl/SaProt_650M_AF2_inverse_folding\"\n", - "\n", - "config = {\n", - " \"config_path\": base_model,\n", - " \"load_pretrained\": True,\n", - "}\n", - "from saprot.model.saprot.saprot_if_model import SaProtIFModel\n", - "try:\n", - " saprot_if_model\n", - "except Exception:\n", - " saprot_if_model = SaProtIFModel(**config)\n", - " tokenizer = saprot_if_model.tokenizer\n", - " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - " saprot_if_model.to(device)\n", - "\n", - "################################################################################\n", - "############################### Predict ########################################\n", - "################################################################################\n", - "\n", - "pred_aa_seqs = saprot_if_model.predict(masked_aa_seq, masked_struc_seq, method=method, num_samples=num_samples)\n", - "\n", - "clear_output(wait=True)\n", - "print('='*100)\n", - "print(Fore.BLUE+\"Outputs:\"+Style.RESET_ALL)\n", - "save_path = f\"{root_dir}/SaprotHub/output/{save_name}.fasta\"\n", - "with open(save_path, \"w\") as w:\n", - " for i, aa_seq in enumerate(pred_aa_seqs):\n", - " print(aa_seq)\n", - " w.write(f\">predicted_seq_{i}\\n{aa_seq}\\n\\n\")\n", - "\n", - "file_download(save_path)\n" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "metadata": { - "cellView": "form", - "id": "idDDKW2pl2Gf" - }, - "outputs": [], - "source": [ - "#@title **3.3.3: Predict the structure of generated sequence**\n", - "\n", - "#@markdown **Warning: Please make sure you have enough RAM to run this cell. (Do not use Colab's T4 or local server with less than 64G RAM).**\n", - "#@markdown\n", - "#@markdown Otherwise, it will cause an out-of-memory error, and you will have to restart the notebook. We recommend you connect to a runtime\n", - "#@markdown with more RAM to run the cell properly.\n", - "\n", - "#@markdown Click the run button to predict the structure of generated sequence using ESMFold\n", - "\n", - "protein_sequence = \"QQVGSQLDLQEESVEYQIFPTQTHQNDTKNVKERLESILERINSIFIPYSQDYVWQEKELSFMISLGLQQGRPHLMGSTHFGDNIDDEWFLVNLLKQLSQVFPQLTAKISDSDGEFLLIEAADLMPKWLIPEGIENRVFVYNGNLMIIPFLLGRYSLIHYQLSKPSLDQAIDLLRNFPEETRASRDQQKLIHNRINGIIKSFLAGTHKAYCYIPRPIATLLKRKPSLVSHAVETFYYRDPIDVKNCRNMEDFTNAERLRVDVRFTRVLYAQLVSQSFNPPKQMGIEAPDPEDKEFKRELLGMKLTCGFAMMAANLLPSTVDPSLNGWAYLEQFKRFRENVEKGNATAKISEPDDQLELISAVRKFLRYIVEDHIDASILKSLLVVELHRQKQMLPESEEAIRKIKKTLLERWNPGWQMSEEYREKTVGQVENGGDSSCEALKSDSKRADLADLDMGRVQDLSRFIDKESRPLERSKISDLQPEVVMGMEQEEDAAAAVSKVYKGGPYLVPIADLKERPEAVHPKATQVVQGELLLISAEDQESKTSNRRVRFGQHGQSQDQAAPMLVGCDRMTALDSIVPEKEEDKVKKGLGYIHLEKSTNSLLTHAIYKIQGSVSHVAARLADRGIDVTSDNVPIKPQTMEEG\" # @param {type:\"string\"}\n", - "save_name = \"predicted_structure\" # @param {type:\"string\"}\n", - "\n", - "#@markdown Visualization settings\n", - "color = \"lDDT\" #@param [\"chain\", \"lDDT\", \"rainbow\"]\n", - "show_sidechains = False #@param {type:\"boolean\"}\n", - "show_mainchains = False #@param {type:\"boolean\"}\n", - "\n", - "\n", - "################################################################################\n", - "############################### LOAD ESMFOLD ################################\n", - "################################################################################\n", - "try:\n", - " esmfold\n", - "except Exception:\n", - " tokenizer = AutoTokenizer.from_pretrained(\"facebook/esmfold_v1\")\n", - " esmfold = EsmForProteinFolding.from_pretrained(\"facebook/esmfold_v1\")\n", - " esmfold.esm = esmfold.esm.half()\n", - " esmfold.trunk.set_chunk_size(64)\n", - "\n", - " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - " esmfold.to(device)\n", - "\n", - "################################################################################\n", - "################################## PREDICT ###################################\n", - "################################################################################\n", - "tokenized_input = tokenizer(\n", - " [protein_sequence],\n", - " return_tensors=\"pt\",\n", - " add_special_tokens=False,\n", - " )['input_ids']\n", - "\n", - "tokenized_input = tokenized_input.to(esmfold.device)\n", - "with torch.no_grad():\n", - " output = esmfold(tokenized_input)\n", - "\n", - "################################################################################\n", - "#################################### SAVE ####################################\n", - "################################################################################\n", - "save_path = f\"{root_dir}/SaprotHub/output/{save_name}.pdb\"\n", - "pdb = convert_outputs_to_pdb(output)\n", - "with open(save_path, \"w\") as f:\n", - " f.write(\"\".join(pdb))\n", - "\n", - "################################################################################\n", - "################################# VISUALIZE 
##################################\n", - "################################################################################\n", - "show_pdb(save_path, show_sidechains, show_mainchains, color).show()\n", - "if color == \"lDDT\":\n", - " plot_plddt_legend().show()\n", - "\n", - "print(\"Predicted structure\")\n", - "file_download(save_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "PecOgaEEZw4C" - }, - "outputs": [], - "source": [ - "#@title **3.3.4: Align proteins using TMalign**\n", - "\n", - "#@markdown You can find the **uploaded proteins** from /content/SaprotHub/structures (if you connect to your local server, then the path is /SaprotHub/structures).\n", - "\n", - "#@markdown You can find the **predicted proteins** from /content/SaprotHub/output (if you connect to your local server, then the path is /SaprotHub/output).\n", - "\n", - "#@markdown Right click the **pdb file** to copy the path and then paste it into the box:\n", - "pdb_path_1 = \"/SaprotHub/structures/1qsf.pdb\" # @param {type:\"string\"}\n", - "pdb_path_2 = \"/SaprotHub/output/predicted_structure.pdb\" # @param {type:\"string\"}\n", - "\n", - "pdb_path_1 = f\"{root_dir}{pdb_path_1}\"\n", - "pdb_path_2 = f\"{root_dir}{pdb_path_2}\"\n", - "\n", - "assert os.path.exists(pdb_path_1) and os.path.exists(pdb_path_2), \"Input proteins do not exist!\"\n", - "\n", - "cmd = f\"{root_dir}/SaprotHub/bin/TMalign {pdb_path_1} {pdb_path_2}\"\n", - "print(os.popen(cmd).read())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "t8i9p_Ff4232" - }, - "source": [ - "## **3.4 Extract Protein Embedding**\n", - "\n", - "The shape of extracted embedding is `[N, D]`, where `N` is the number of sequences and `D` is the hidden dimension." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "aA7QozqDh8Nj" - }, - "outputs": [], - "source": [ - "################################################################################\n", - "########################## MODEL ###############################################\n", - "################################################################################\n", - "use_model_from = \"Official pretrained SaProt (35M)\" # @param [\"Official pretrained SaProt (35M)\",\"Official pretrained SaProt (650M)\",\"Trained by yourself on ColabSaprot\",\"Shared by peers on SaprotHub\",\"Saved in your local computer\"]\n", - "if use_model_from == \"Multi-models on SaprotHub\":\n", - " multi_lora = True\n", - "else:\n", - " multi_lora = False\n", - "\n", - "adapter_input = select_adapter_from(None, use_model_from)\n", - "\n", - "################################################################################\n", - "########################## DATASET #############################################\n", - "################################################################################\n", - "data_type = \"Multiple SA Sequences\" # @param [\"Single SA Sequence\", \"Single UniProt ID\", \"Single PDB/CIF Structure\", \"Multiple SA Sequences\", \"Multiple UniProt IDs\", \"Multiple PDB/CIF Structures\"]\n", - "# check_task_type_and_data_type(original_task_type, data_type)\n", - "\n", - "mode = \"Multiple Sequences\" if (data_type in data_type_list_multiple) else \"Single Sequence\"\n", - "\n", - "raw_data = input_raw_data_by_data_type(data_type)\n", - "\n", - "################################################################################\n", - "########################## EXTRACT #############################################\n", - "################################################################################\n", - "def extract(button):\n", - " button.disabled = True\n", - " button.description = 'Extracting...'\n", - " button.button_style = ''\n", - "\n", - " print('\\n')\n", - " print('='*100)\n", - "\n", - " ##############################################################################\n", - " ################################# MODEL ###################################\n", - " ##############################################################################\n", - " if multi_lora:\n", - " if use_model_from == \"Multi-models on ColabSaprot\":\n", - " config_list = [EasyDict({'lora_config_path': ADAPTER_HOME / task_type / lora_config_path}) for lora_config_path in list(adapter_input.value)]\n", - " elif use_model_from == \"Multi-models on SaprotHub\":\n", - " #1. get adapter_list\n", - " repo_id_list = adapter_input.value.replace(\" \", \"\").split(',')\n", - " #2. 
download adapters\n", - " for repo_id in repo_id_list:\n", - " snapshot_download(repo_id=repo_id, repo_type=\"model\", local_dir=ADAPTER_HOME / task_type / repo_id)\n", - " config_list = [EasyDict({'lora_config_path': ADAPTER_HOME / task_type / repo_id}) for repo_id in repo_id_list]\n", - "\n", - " assert len(config_list) > 0, \"Please select your models from the dropdown menu on the output of 3.1!\"\n", - " base_model = get_base_model(ADAPTER_HOME / task_type / config_list[0].lora_config_path)\n", - "\n", - " required_training_data_type_list = []\n", - " for lora_config in config_list:\n", - " required_training_data_type_list.append(check_training_data_type(lora_config.lora_config_path, data_type))\n", - " assert len(set(required_training_data_type_list)) == 1, f\"Error: The input data types of these models are not identical: {required_training_data_type_list}\"\n", - " required_training_data_type = required_training_data_type_list[0]\n", - "\n", - " lora_kwargs = EasyDict({\n", - " \"is_trainable\": False,\n", - " \"num_lora\": len(config_list),\n", - " \"config_list\": config_list\n", - " })\n", - "\n", - " elif use_model_from == \"Official pretrained SaProt (35M)\":\n", - " base_model = \"westlake-repl/SaProt_35M_AF2\"\n", - " lora_kwargs = None\n", - "\n", - " elif use_model_from == \"Official pretrained SaProt (650M)\":\n", - " base_model = \"westlake-repl/SaProt_650M_AF2\"\n", - " lora_kwargs = None\n", - "\n", - " else:\n", - " adapter_path = ADAPTER_HOME / adapter_input.value\n", - "\n", - " if use_model_from == \"Shared by peers on SaprotHub\":\n", - " snapshot_download(repo_id=adapter_input.value, repo_type=\"model\", local_dir=adapter_path)\n", - "\n", - " base_model = get_base_model(adapter_path)\n", - " required_training_data_type = check_training_data_type(adapter_path, data_type)\n", - " lora_kwargs = {\n", - " \"is_trainable\": False,\n", - " \"num_lora\": 1,\n", - " \"config_list\": [{\"lora_config_path\": adapter_path}]\n", - " }\n", - "\n", - " ##############################################################################\n", - " ################################# DATASET ###################################\n", - " ##############################################################################\n", - " if data_type in data_type_list_multiple:\n", - " csv_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)\n", - " df = read_csv_dataset(csv_dataset_path)\n", - " else:\n", - " single_sa_seq = get_SA_sequence_by_data_type(data_type, raw_data)\n", - " # if task_type in [\"pair_classification\", \"pair_regression\"]:\n", - " # df = pd.DataFrame({\n", - " # 'sequence_1': [single_sa_seq[0]],\n", - " # 'sequence_2': [single_sa_seq[1]]\n", - " # })\n", - " # else:\n", - " df = pd.DataFrame({\n", - " 'sequence': [single_sa_seq]\n", - " })\n", - "\n", - " # if (required_training_data_type == \"AA\") and (\"AA\" not in data_type):\n", - " # if 'sequence' in df.columns:\n", - " # df['sequence'] = df['sequence'].apply(mask_struc_token)\n", - " # elif 'sequence_1' in df.columns and 'sequence_2' in df.columns:\n", - " # df['sequence_1'] = df['sequence_1'].apply(mask_struc_token)\n", - " # df['sequence_2'] = df['sequence_2'].apply(mask_struc_token)\n", - "\n", - " ################################################################################\n", - " ##################################### CONFIG ###################################\n", - " ################################################################################\n", - " from saprot.config.config_dict import 
Default_config\n", - " config = copy.deepcopy(Default_config)\n", - "\n", - "\n", - " # task\n", - " if use_model_from in [\"Official pretrained SaProt (35M)\", \"Official pretrained SaProt (650M)\"]:\n", - " num_labels, task_type = 1, 'classification'\n", - " else:\n", - " num_labels, task_type = get_num_labels_and_task_type_by_adapter(lora_kwargs[\"config_list\"][0][\"lora_config_path\"])\n", - "\n", - " config.model.kwargs.num_labels = num_labels\n", - " # base model\n", - " config.model.model_py_path = model_type_dict[task_type]\n", - " config.model.kwargs.config_path = base_model\n", - " # lora\n", - " config.model.kwargs.lora_kwargs = lora_kwargs\n", - "\n", - " ################################################################################\n", - " ################################### LOAD MODEL ##################################\n", - " ################################################################################\n", - " model = my_load_model(config.model)\n", - " tokenizer = EsmTokenizer.from_pretrained(config.model.kwargs.config_path)\n", - " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - " model.to(device)\n", - "\n", - " ################################################################################\n", - " ################################### INFO #######################################\n", - " ################################################################################\n", - " # clear_output(wait=True)\n", - " print('\\n')\n", - " print('='*100)\n", - "\n", - " # print(Fore.BLUE+f\"Task Type: {original_task_type}\"+Style.RESET_ALL)\n", - "\n", - " print(Fore.BLUE+f\"Model ({use_model_from}):\"+Style.RESET_ALL)\n", - " if use_model_from in [\"Official pretrained SaProt (35M)\", \"Official pretrained SaProt (650M)\"]:\n", - " print(Fore.BLUE+f\" Base Model: {base_model}\"+Style.RESET_ALL)\n", - " elif multi_lora:\n", - " print(Fore.BLUE+f\" Base Model: {base_model}\"+Style.RESET_ALL)\n", - " print(Fore.BLUE+f\" Adapter:\"+Style.RESET_ALL)\n", - " for lora_config in lora_kwargs.config_list:\n", - " print(Fore.BLUE+f\" {lora_config.lora_config_path}\"+Style.RESET_ALL)\n", - " else:\n", - " print(Fore.BLUE+f\" Base Model: {base_model}\"+Style.RESET_ALL)\n", - " print(Fore.BLUE+f\" Adapter: {adapter_path}\"+Style.RESET_ALL)\n", - "\n", - " print(Fore.BLUE+f'Dataset ({data_type}):' +Style.RESET_ALL)\n", - " if mode == \"Multiple Sequences\":\n", - " print(Fore.BLUE+f\" CSV Dataset Path: {csv_dataset_path}\"+Style.RESET_ALL)\n", - " else:\n", - " # if \"A pair of\" in data_type:\n", - " # print(Fore.BLUE+f\" Sequence 1: {single_sa_seq[0]}\"+Style.RESET_ALL)\n", - " # print(Fore.BLUE+f\" Sequence 2: {single_sa_seq[1]}\"+Style.RESET_ALL)\n", - " # else:\n", - " print(Fore.BLUE+f\" Sequence: {single_sa_seq}\"+Style.RESET_ALL)\n", - "\n", - " ################################################################################\n", - " ################################### INFERENCE ##################################\n", - " ################################################################################\n", - " print('\\n')\n", - " print('='*100)\n", - " print(Fore.BLUE+\"Predicting...\"+Style.RESET_ALL)\n", - "\n", - " seqs = df['sequence']\n", - " embedding_list = []\n", - " with torch.no_grad():\n", - " for seq in tqdm(seqs, total=len(seqs)):\n", - " embedding = model.get_hidden_states_from_seqs(seqs, reduction='mean')\n", - " embedding_list.append(embedding[0])\n", - " embeddings = torch.stack(embedding_list)\n", - " # print(embeddings.shape)\n", - "\n", - " 
print()\n", - " print('='*100)\n", - " print(Fore.BLUE+\"Prediction Result:\"+Style.RESET_ALL)\n", - "\n", - " timestamp = str(datetime.now().strftime(\"%Y%m%d%H%M%S\"))\n", - " embeddings_path = OUTPUT_HOME / f'embeddings_{timestamp}.pt'\n", - " torch.save(embeddings, embeddings_path)\n", - " print(Fore.BLUE+f\"The extracted embeddings is saved to {embeddings_path}.\"+Style.RESET_ALL)\n", - " file_download(embeddings_path)\n", - "\n", - "################################################################################\n", - "#################################### BUTTON #################################\n", - "################################################################################\n", - "button_extract = ipywidgets.Button(\n", - " description='Extract Protein Embeddings',\n", - " disabled=False,\n", - " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", - " tooltip='Apply',\n", - " # icon='check' # (FontAwesome names without the `fa-` prefix)\n", - " )\n", - "button_extract.on_click(extract)\n", - "button_extract.layout.width = '500px'\n", - "display(button_extract)\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LIIoTQgJBI7o" - }, - "source": [ - "# **4: (Optional) Data Preparation**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "collapsed": true, - "id": "BgsBSLcmBI7o" - }, - "outputs": [], - "source": [ - "# @title **4.1: Get Structure-Aware Sequence** \n", - "\n", - "#@markdown AA Sequence, UniProt ID, PDB/CIF file -> SA Sequence\n", - "\n", - "################################################################################\n", - "################################ input #########################################\n", - "################################################################################\n", - "\n", - "data_type = \"Multiple PDB/CIF Structures\" # @param [\"Single AA Sequence\", \"Single UniProt ID\", \"Single PDB/CIF Structure\", \"Multiple AA Sequences\", \"Multiple UniProt IDs\", \"Multiple PDB/CIF Structures\"]\n", - "raw_data = input_raw_data_by_data_type(data_type)\n", - "\n", - "################################################################################\n", - "############################### output #########################################\n", - "################################################################################\n", - "\n", - "if data_type in [\"Single AA Sequence\", \"Single UniProt ID\", \"Single PDB/CIF Structure\"]:\n", - " def apply(button):\n", - " button.disabled = True\n", - " button.description = 'Clicked'\n", - " button.button_style = ''\n", - " sa_seq = get_SA_sequence_by_data_type(data_type, raw_data)\n", - "\n", - " print(\"=\"*100)\n", - " print(f\"Amino Acid Sequence: {sa_seq[0::2]}\")\n", - " print(f\"Structure Sequence: {sa_seq[1::2]}\")\n", - " print(\"=\"*100)\n", - " print(\"Please note that structure tokens with a plDDT score lower than 70% are denoted as \\\"#\\\"\")\n", - " print(Fore.BLUE + \"The Structure-Aware Sequence is here, double click to select and copy it:\" + Style.RESET_ALL)\n", - " print(sa_seq)\n", - "\n", - " button_apply = ipywidgets.Button(\n", - " description='Apply',\n", - " disabled=False,\n", - " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", - " tooltip='Apply',\n", - " icon='check' # (FontAwesome names without the `fa-` prefix)\n", - " )\n", - " button_apply.on_click(apply)\n", - " button_apply.layout.width = '500px'\n", - " display(button_apply)\n", - "else:\n", 
- " csv_dataset = get_SA_sequence_by_data_type(data_type, raw_data)\n", - " print(Fore.BLUE + \"\\n\\nThe Structure-Aware Sequences are saved in a .csv file here:\" + Style.RESET_ALL)\n", - " print(csv_dataset)\n", - " file_download(csv_dataset)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "IDkm_OeABI7o" - }, - "outputs": [], - "source": [ - "#@title **4.2: Convert `.fa/.fasta` file to `.csv` file in the data format of \"Multiple AA Sequences\"**\n", - "\n", - "\n", - "#@markdown `.fa/.fasta` -> Multiple AA Sequences `.csv` \n", - "\n", - "from Bio import SeqIO\n", - "import numpy as np\n", - "\n", - "aa_seq_dict = { 'sequence': [],\n", - " # \"label\": [],\n", - " # \"stage\":[]\n", - " }\n", - "\n", - "fa_file_path = upload_file(UPLOAD_FILE_HOME)\n", - "assert Path(fa_file_path).name.split('.')[1] in ['fa', 'fasta'], \"Please upload a .fa or .fasta file.\"\n", - "with fa_file_path.open(\"r\") as fa:\n", - " for record in tqdm(SeqIO.parse(fa, 'fasta'), leave=True):\n", - " aa_seq_dict['sequence'].append(str(record.seq))\n", - "\n", - "fa_df = pd.DataFrame(aa_seq_dict)\n", - "print(fa_df[5:])\n", - "\n", - "csv_file_path = UPLOAD_FILE_HOME / f'{fa_file_path.stem}.csv'\n", - "fa_df.to_csv(csv_file_path, index=None)\n", - "# files.download(csv_file_path)\n", - "file_download(csv_file_path)\n", - "\n", - "################################################################################\n", - "############################ .fa 2 .csv and split ##############################\n", - "################################################################################\n", - "\n", - "# automatically_split_dataset = False # @param {type:\"boolean\"}\n", - "# split = ['train', 'valid', 'test']\n", - "\n", - "# aa_seq_dict = { 'sequence': [],\n", - "# \"label\": [],\n", - "# \"stage\":[]}\n", - "\n", - "\n", - "\n", - "# if automatically_split_dataset:\n", - "\n", - "# fa_file_path = upload_file(UPLOAD_FILE_HOME)\n", - "# label = fa_file_path.stem\n", - "\n", - "# with fa_file_path.open(\"r\") as fa:\n", - "# for record in tqdm(SeqIO.parse(fa, 'fasta'), leave=True):\n", - "# aa_seq_dict['sequence'].append(str(record.seq))\n", - "# aa_seq_dict[\"label\"].append(label)\n", - "# weights = [0.8, 0.1, 0.1]\n", - "# aa_seq_dict[\"stage\"] = np.random.choice(split, size=len(aa_seq_dict['sequence']), p=weights).tolist()\n", - "\n", - "# else:\n", - "# for i in range(3):\n", - "# print(Fore.BLUE+f\"Please upload a .fa file as your {split[i]} dataset\")\n", - "# fa_file_path = upload_file(UPLOAD_FILE_HOME)\n", - "# label = fa_file_path.stem\n", - "\n", - "# with fa_file_path.open(\"r\") as fa:\n", - "# for record in tqdm(SeqIO.parse(fa, 'fasta')):\n", - "# aa_seq_dict['sequence'].append(str(record.seq))\n", - "# aa_seq_dict[\"label\"].append(label)\n", - "# aa_seq_dict[\"stage\"].append(split[i])\n", - "\n", - "# print()\n", - "# print(\"=\"*100)\n", - "\n", - "# fa_df = pd.DataFrame(aa_seq_dict)\n", - "# timestamp = datetime.now().strftime(\"%y%m%d%H%M%S\")\n", - "# fa_df.to_csv(f'/content/SaprotHub/upload_files/{timestamp}.csv', index=None)\n", - "# files.download(f'/content/SaprotHub/upload_files/{timestamp}.csv')\n", - "# print(fa_df[5:])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "xtadHW9vBI7o" - }, - "outputs": [], - "source": [ - "#@title **4.3: Dataset Split** \n", - "\n", - "#@markdown Randomly split your .csv dataset by adding a `stage` column, assigning the values `train`, 
`valid`, and `test` according to the specified split ratio (default is `0.8:0.1:0.1`). Ensure that the sum of the ratios equals 1 and that each ratio is greater than 0.\n", - "\n", - "#@markdown Your .csv dataset should contain `sequence` and `label` columns, making it suitable for training after splitting. For the specific format, please refer to [tutorial](https://github.com/westlake-repl/SaprotHub/wiki/2.1:-Train-your-model#dataset-format).\n", - "\n", - "# @markdown Please click the run button to upload your .csv dataset\n", - "\n", - "split = ['train', 'valid', 'test']\n", - "# split_ratio = [0.8, 0.1, 0.1]\n", - "split_ratio = 0.8, 0.1, 0.1 # @param {\"type\":\"raw\",\"placeholder\":\"0.8, 0.1, 0.1\"}\n", - "\n", - "if any(w == 0 or w == 1 for w in split_ratio):\n", - " raise ValueError(\"One or more proportions for train, valid, and test are either 0 or 1. Please ensure all values are between 0 and 1.\")\n", - "elif sum(split_ratio) != 1:\n", - " raise ValueError(\"The sum of the proportions for train, valid, and test is not equal to 1. Please check the values.\")\n", - "else:\n", - " print(f\"The split ratio for train, valid, and test is {split_ratio[0]}:{split_ratio[1]}:{split_ratio[2]}.\")\n", - "\n", - "print('='*100)\n", - "print(\"Upload your .csv dataset:\")\n", - "csv_dataset_path = upload_file(UPLOAD_FILE_HOME)\n", - "dataset_df = read_csv_dataset(csv_dataset_path)\n", - "\n", - "while ('stage' not in dataset_df.columns) or (dataset_df[\"stage\"].nunique()<3):\n", - " dataset_df[\"stage\"] = np.random.choice(split, size=len(dataset_df), p=split_ratio).tolist()\n", - "\n", - "dataset_df.to_csv(csv_dataset_path, index=None)\n", - "file_download(csv_dataset_path)\n" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "provenance": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/notebooks/protein_language_modeling.ipynb b/notebooks/protein_language_modeling.ipynb deleted file mode 100644 index ccacd47..0000000 --- a/notebooks/protein_language_modeling.ipynb +++ /dev/null @@ -1,2554 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "af5d6f2e", - "metadata": { - "id": "af5d6f2e" - }, - "source": [ - "If you're opening this Notebook on colab, you will probably need to install 🤗 Transformers as well as some other libraries. Uncomment the following cell and run it.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4c5bf8d4", - "metadata": { - "id": "4c5bf8d4" - }, - "outputs": [], - "source": [ - "#! pip install transformers evaluate datasets requests pandas sklearn" - ] - }, - { - "cell_type": "markdown", - "id": "76e71a3f", - "metadata": { - "id": "76e71a3f" - }, - "source": [ - "If you're opening this notebook locally, make sure your environment has an install from the last version of those libraries.\n", - "\n", - "To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.\n", - "\n", - "First you have to store your authentication token from the Hugging Face website (sign up here if you haven't already!) 
then execute the following cell and input your username and password:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "25b8526a", - "metadata": { - "id": "25b8526a" - }, - "outputs": [], - "source": [ - "from huggingface_hub import notebook_login\n", - "\n", - "notebook_login()" - ] - }, - { - "cell_type": "markdown", - "id": "ab8b2712", - "metadata": { - "id": "ab8b2712" - }, - "source": [ - "Then you need to install Git-LFS. Uncomment the following instructions:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "19e8f77c", - "metadata": { - "id": "19e8f77c" - }, - "outputs": [], - "source": [ - "# !apt install git-lfs" - ] - }, - { - "cell_type": "markdown", - "id": "80dbad4e", - "metadata": { - "id": "80dbad4e" - }, - "source": [ - "We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d107b8d9", - "metadata": { - "id": "d107b8d9" - }, - "outputs": [], - "source": [ - "from transformers.utils import send_example_telemetry\n", - "\n", - "send_example_telemetry(\"protein_language_modeling_notebook\", framework=\"pytorch\")" - ] - }, - { - "cell_type": "markdown", - "id": "5c0749e1", - "metadata": { - "id": "5c0749e1" - }, - "source": [ - "# Fine-Tuning Protein Language Models\n" - ] - }, - { - "cell_type": "markdown", - "id": "1d81db83", - "metadata": { - "id": "1d81db83" - }, - "source": [ - "In this notebook, we're going to do some transfer learning to fine-tune some large, pre-trained protein language models on tasks of interest. If that sentence feels a bit intimidating to you, don't panic - there's [a blog post](https://huggingface.co/blog/deep-learning-with-proteins) that explains the concepts here in much more detail.\n", - "\n", - "The specific model we're going to use is ESM-2, which is the state-of-the-art protein language model at the time of writing (November 2022). The citation for this model is [Lin et al, 2022](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1).\n", - "\n", - "There are several ESM-2 checkpoints with differing model sizes. Larger models will generally have better accuracy, but they require more GPU memory and will take much longer to train. The available ESM-2 checkpoints (at time of writing) are:\n", - "\n", - "| Checkpoint name | Num layers | Num parameters |\n", - "| --------------------- | ---------- | -------------- |\n", - "| `esm2_t48_15B_UR50D` | 48 | 15B |\n", - "| `esm2_t36_3B_UR50D` | 36 | 3B |\n", - "| `esm2_t33_650M_UR50D` | 33 | 650M |\n", - "| `esm2_t30_150M_UR50D` | 30 | 150M |\n", - "| `esm2_t12_35M_UR50D` | 12 | 35M |\n", - "| `esm2_t6_8M_UR50D` | 6 | 8M |\n", - "\n", - "Note that the larger checkpoints may be very difficult to train without a large cloud GPU like an A100 or H100, and the largest 15B parameter checkpoint will probably be impossible to train on **any** single GPU! Also, note that memory usage for attention during training will scale as `O(batch_size * num_layers * seq_len^2)`, so larger models on long sequences will use quite a lot of memory! 
We will use the `esm2_t12_35M_UR50D` checkpoint for this notebook, which should train on any Colab instance or modern GPU.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32e605a2", - "metadata": { - "id": "32e605a2" - }, - "outputs": [], - "source": [ - "model_checkpoint = \"facebook/esm2_t12_35M_UR50D\"" - ] - }, - { - "cell_type": "markdown", - "id": "a8e6ac19", - "metadata": { - "id": "a8e6ac19" - }, - "source": [ - "# Sequence classification\n" - ] - }, - { - "cell_type": "markdown", - "id": "c3eb400c", - "metadata": { - "id": "c3eb400c" - }, - "source": [ - "One of the most common tasks you can perform with a language model is **sequence classification**. In sequence classification, we classify an entire protein into a category, from a list of two or more possibilities. There's no limit on the number of categories you can use, or the specific problem you choose, as long as it's something the model could in theory infer from the raw protein sequence. To keep things simple for this example, though, let's try classifying proteins by their cellular localization - given their sequence, can we predict if they're going to be found in the cytosol (the fluid inside the cell) or embedded in the cell membrane?\n" - ] - }, - { - "cell_type": "markdown", - "id": "c5bc122f", - "metadata": { - "id": "c5bc122f" - }, - "source": [ - "## Data preparation\n" - ] - }, - { - "cell_type": "markdown", - "id": "4c91d394", - "metadata": { - "id": "4c91d394" - }, - "source": [ - "In this section, we're going to gather some training data from UniProt. Our goal is to create a pair of lists: `sequences` and `labels`. `sequences` will be a list of protein sequences, which will just be strings like \"MNKL...\", where each letter represents a single amino acid in the complete protein. `labels` will be a list of the category for each sequence. The categories will just be integers, with 0 representing the first category, 1 representing the second and so on. In other words, if `sequences[i]` is a protein sequence then `labels[i]` should be its corresponding category. These will form the **training data** we're going to use to teach the model the task we want it to do.\n", - "\n", - "If you're adapting this notebook for your own use, this will probably be the main section you want to change! You can do whatever you want here, as long as you create those two lists by the end of it. If you want to follow along with this example, though, first we'll need to `import requests` and set up our query to UniProt.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c718ffbc", - "metadata": { - "id": "c718ffbc" - }, - "outputs": [], - "source": [ - "import requests\n", - "\n", - "query_url = \"https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Ccc_subcellular_location&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B80%20TO%20500%5D%29%29\"" - ] - }, - { - "cell_type": "markdown", - "id": "3d2edc14", - "metadata": { - "id": "3d2edc14" - }, - "source": [ - "This query URL might seem mysterious, but it isn't! 
To get it, we searched for `(organism_id:9606) AND (reviewed:true) AND (length:[80 TO 500])` on UniProt to get a list of reasonably-sized human proteins,\n", - "then selected 'Download', and set the format to TSV and the columns to `Sequence` and `Subcellular location [CC]`, since those contain the data we care about for this task.\n", - "\n", - "Once that's done, selecting `Generate URL for API` gives you a URL you can pass to Requests. Alternatively, if you're not on Colab you can just download the data through the web interface and open the file locally.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dd03ef98", - "metadata": { - "id": "dd03ef98" - }, - "outputs": [], - "source": [ - "uniprot_request = requests.get(query_url)" - ] - }, - { - "cell_type": "markdown", - "id": "b7217b77", - "metadata": { - "id": "b7217b77" - }, - "source": [ - "To get this data into Pandas, we use a `BytesIO` object, which Pandas will treat like a file. If you downloaded the data as a file you can skip this bit and just pass the filepath directly to `read_csv`.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f2c05017", - "metadata": { - "id": "f2c05017", - "outputId": "5883838a-a64d-4709-f56f-fda7c356e886" - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
EntrySequenceSubcellular location [CC]
0A0A0K2S4Q6MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E...
1A0A5B9DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...SUBCELLULAR LOCATION: Cell membrane {ECO:00003...
2A0AVI4MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA...SUBCELLULAR LOCATION: Endoplasmic reticulum me...
3A0JLT2MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...SUBCELLULAR LOCATION: Nucleus {ECO:0000305}.
4A0M8Q6GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu...
............
11977Q9NZ38MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG...NaN
11978Q9UFV3MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV...NaN
11979Q9Y6C7MAHHSLNTFYIWHNNVLHTHLVFFLPHLLNQPFSRGSFLIWLLLCW...NaN
11980X6R8D5MGRKEHESPSQPHMCGWEDSQKPSVPSHGPKTPSCKGVKAPHSSRP...NaN
11981X6R8R1MGVVLSPHPAPSRREPLAPLAPGTRPGWSPAVSGSSRSALRPSTAG...NaN
\n", - "

11982 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " Entry Sequence \\\n", - "0 A0A0K2S4Q6 MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE... \n", - "1 A0A5B9 DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV... \n", - "2 A0AVI4 MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA... \n", - "3 A0JLT2 MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP... \n", - "4 A0M8Q6 GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG... \n", - "... ... ... \n", - "11977 Q9NZ38 MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG... \n", - "11978 Q9UFV3 MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV... \n", - "11979 Q9Y6C7 MAHHSLNTFYIWHNNVLHTHLVFFLPHLLNQPFSRGSFLIWLLLCW... \n", - "11980 X6R8D5 MGRKEHESPSQPHMCGWEDSQKPSVPSHGPKTPSCKGVKAPHSSRP... \n", - "11981 X6R8R1 MGVVLSPHPAPSRREPLAPLAPGTRPGWSPAVSGSSRSALRPSTAG... \n", - "\n", - " Subcellular location [CC] \n", - "0 SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E... \n", - "1 SUBCELLULAR LOCATION: Cell membrane {ECO:00003... \n", - "2 SUBCELLULAR LOCATION: Endoplasmic reticulum me... \n", - "3 SUBCELLULAR LOCATION: Nucleus {ECO:0000305}. \n", - "4 SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu... \n", - "... ... \n", - "11977 NaN \n", - "11978 NaN \n", - "11979 NaN \n", - "11980 NaN \n", - "11981 NaN \n", - "\n", - "[11982 rows x 3 columns]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from io import BytesIO\n", - "import pandas\n", - "\n", - "bio = BytesIO(uniprot_request.content)\n", - "\n", - "df = pandas.read_csv(bio, compression=\"gzip\", sep=\"\\t\")\n", - "df" - ] - }, - { - "cell_type": "markdown", - "id": "0bcdf34b", - "metadata": { - "id": "0bcdf34b" - }, - "source": [ - "Nice! Now we have some proteins and their subcellular locations. Let's start filtering this down. First, let's ditch the columns without subcellular location information.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "31d87663", - "metadata": { - "id": "31d87663" - }, - "outputs": [], - "source": [ - "df = df.dropna() # Drop proteins with missing columns" - ] - }, - { - "cell_type": "markdown", - "id": "10d1af5c", - "metadata": { - "id": "10d1af5c" - }, - "source": [ - "Now we'll make one dataframe of proteins that contain `cytosol` or `cytoplasm` in their subcellular localization column, and a second that mentions the `membrane` or `cell membrane`. To ensure we don't get overlap, we ensure each dataframe only contains proteins that don't match the other search term.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c831bb16", - "metadata": { - "id": "c831bb16" - }, - "outputs": [], - "source": [ - "cytosolic = df[\"Subcellular location [CC]\"].str.contains(\"Cytosol\") | df[\n", - " \"Subcellular location [CC]\"\n", - "].str.contains(\"Cytoplasm\")\n", - "membrane = df[\"Subcellular location [CC]\"].str.contains(\"Membrane\") | df[\n", - " \"Subcellular location [CC]\"\n", - "].str.contains(\"Cell membrane\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f41139a2", - "metadata": { - "id": "f41139a2", - "outputId": "8cc6dd3b-a805-40fe-afe9-f00fc836cd99" - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
EntrySequenceSubcellular location [CC]
10A1E959MKIIILLGFLGATLSAPLIPQRLMSASNSNELLLNLNNGQLLPLQL...SUBCELLULAR LOCATION: Secreted {ECO:0000250|Un...
15A1XBS5MMRRTLENRNAQTKQLQTAVSNVEKHFGELCQIFAAYVRKTARLRD...SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...
19A2RU49MSSGNYQQSEALSKPTFSEEQASALVESVFGLKVSKVRPLPSYDDQ...SUBCELLULAR LOCATION: Cytoplasm {ECO:0000305}.
21A2RUH7MEAATAPEVAAGSKLKVKEASPADAEPPQASPGQGAGSPTPQLLPP...SUBCELLULAR LOCATION: Cytoplasm, myofibril, sa...
22A4D126MEAGPPGSARPAEPGPCLSGQRGADHTASASLQSVAGTEPGRHPQA...SUBCELLULAR LOCATION: Cytoplasm, cytosol {ECO:...
............
11555Q96L03MATLARLQARSSTVGNQYYFRNSVVDPFRKKENDAAVKIQSWFRGC...SUBCELLULAR LOCATION: Cytoplasm {ECO:0000250}.
11597Q9BYD9MNHCQLPVVIDNGSGMIKAGVAGCREPQFIYPNIIGRAKGQSRAAQ...SUBCELLULAR LOCATION: Cytoplasm, cytoskeleton ...
11639Q9NPB0MEQRLAEFRAARKRAGLAAQPPAASQGAQTPGEKAEAAATLKAAPG...SUBCELLULAR LOCATION: Cytoplasmic vesicle memb...
11652Q9NUJ7MGGQVSASNSFSRLHCRNANEDWMSALCPRLWDVPLHHLSIPGSHD...SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...
11662Q9P2W6MGRTWCGMWRRRRPGRRSAVPRWPHLSSQSGVEPPDRWTGTPGWPS...SUBCELLULAR LOCATION: Cytoplasm.
\n", - "

2495 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " Entry Sequence \\\n", - "10 A1E959 MKIIILLGFLGATLSAPLIPQRLMSASNSNELLLNLNNGQLLPLQL... \n", - "15 A1XBS5 MMRRTLENRNAQTKQLQTAVSNVEKHFGELCQIFAAYVRKTARLRD... \n", - "19 A2RU49 MSSGNYQQSEALSKPTFSEEQASALVESVFGLKVSKVRPLPSYDDQ... \n", - "21 A2RUH7 MEAATAPEVAAGSKLKVKEASPADAEPPQASPGQGAGSPTPQLLPP... \n", - "22 A4D126 MEAGPPGSARPAEPGPCLSGQRGADHTASASLQSVAGTEPGRHPQA... \n", - "... ... ... \n", - "11555 Q96L03 MATLARLQARSSTVGNQYYFRNSVVDPFRKKENDAAVKIQSWFRGC... \n", - "11597 Q9BYD9 MNHCQLPVVIDNGSGMIKAGVAGCREPQFIYPNIIGRAKGQSRAAQ... \n", - "11639 Q9NPB0 MEQRLAEFRAARKRAGLAAQPPAASQGAQTPGEKAEAAATLKAAPG... \n", - "11652 Q9NUJ7 MGGQVSASNSFSRLHCRNANEDWMSALCPRLWDVPLHHLSIPGSHD... \n", - "11662 Q9P2W6 MGRTWCGMWRRRRPGRRSAVPRWPHLSSQSGVEPPDRWTGTPGWPS... \n", - "\n", - " Subcellular location [CC] \n", - "10 SUBCELLULAR LOCATION: Secreted {ECO:0000250|Un... \n", - "15 SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P... \n", - "19 SUBCELLULAR LOCATION: Cytoplasm {ECO:0000305}. \n", - "21 SUBCELLULAR LOCATION: Cytoplasm, myofibril, sa... \n", - "22 SUBCELLULAR LOCATION: Cytoplasm, cytosol {ECO:... \n", - "... ... \n", - "11555 SUBCELLULAR LOCATION: Cytoplasm {ECO:0000250}. \n", - "11597 SUBCELLULAR LOCATION: Cytoplasm, cytoskeleton ... \n", - "11639 SUBCELLULAR LOCATION: Cytoplasmic vesicle memb... \n", - "11652 SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P... \n", - "11662 SUBCELLULAR LOCATION: Cytoplasm. \n", - "\n", - "[2495 rows x 3 columns]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cytosolic_df = df[cytosolic & ~membrane]\n", - "cytosolic_df" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "be5c420e", - "metadata": { - "id": "be5c420e", - "outputId": "b8820adb-cb38-4b9b-8cb2-d38ae74024c9" - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
EntrySequenceSubcellular location [CC]
0A0A0K2S4Q6MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E...
1A0A5B9DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...SUBCELLULAR LOCATION: Cell membrane {ECO:00003...
4A0M8Q6GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu...
18A2RU14MAGTVLGVGAGVFILALLWVAVLLLCVLLSRASGAARFSVIFLFFG...SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
35A5X5Y0MEGSWFHRKRFSFYLLLGFLLQGRGVTFTINCSGFGQHGADPTALN...SUBCELLULAR LOCATION: Cell membrane {ECO:00002...
............
11843Q6UWF5MQIQNNLFFCCYTVMSAIFKWLLLYSLPALCFLLGTQESESFHSKA...SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11917Q8N8V8MLLKVRRASLKPPATPHQGAFRAGNVIGQLIYLLTWSLFTAWLRPP...SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11958Q96N68MQGQGALKESHIHLPTEQPEASLVLQGQLAESSALGPKGALRPQAQ...SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11965Q9H0A3MMNNTDFLMLNNPWNKLCLVSMDFCFPLDFVSNLFWIFASKFIIVT...SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ...
11968Q9H354MNKHNLRLVQLASELILIEIIPKLFLSQVTTISHIKREKIPPNHRK...SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
\n", - "

2579 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " Entry Sequence \\\n", - "0 A0A0K2S4Q6 MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE... \n", - "1 A0A5B9 DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV... \n", - "4 A0M8Q6 GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG... \n", - "18 A2RU14 MAGTVLGVGAGVFILALLWVAVLLLCVLLSRASGAARFSVIFLFFG... \n", - "35 A5X5Y0 MEGSWFHRKRFSFYLLLGFLLQGRGVTFTINCSGFGQHGADPTALN... \n", - "... ... ... \n", - "11843 Q6UWF5 MQIQNNLFFCCYTVMSAIFKWLLLYSLPALCFLLGTQESESFHSKA... \n", - "11917 Q8N8V8 MLLKVRRASLKPPATPHQGAFRAGNVIGQLIYLLTWSLFTAWLRPP... \n", - "11958 Q96N68 MQGQGALKESHIHLPTEQPEASLVLQGQLAESSALGPKGALRPQAQ... \n", - "11965 Q9H0A3 MMNNTDFLMLNNPWNKLCLVSMDFCFPLDFVSNLFWIFASKFIIVT... \n", - "11968 Q9H354 MNKHNLRLVQLASELILIEIIPKLFLSQVTTISHIKREKIPPNHRK... \n", - "\n", - " Subcellular location [CC] \n", - "0 SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E... \n", - "1 SUBCELLULAR LOCATION: Cell membrane {ECO:00003... \n", - "4 SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu... \n", - "18 SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ... \n", - "35 SUBCELLULAR LOCATION: Cell membrane {ECO:00002... \n", - "... ... \n", - "11843 SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ... \n", - "11917 SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ... \n", - "11958 SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ... \n", - "11965 SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ... \n", - "11968 SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ... \n", - "\n", - "[2579 rows x 3 columns]" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "membrane_df = df[membrane & ~cytosolic]\n", - "membrane_df" - ] - }, - { - "cell_type": "markdown", - "id": "77e8cea6", - "metadata": { - "id": "77e8cea6" - }, - "source": [ - "We're almost done! Now, let's make a list of sequences from each df and generate the associated labels. We'll use `0` as the label for cytosolic proteins and `1` as the label for membrane proteins.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "023ec31b", - "metadata": { - "id": "023ec31b" - }, - "outputs": [], - "source": [ - "cytosolic_sequences = cytosolic_df[\"Sequence\"].tolist()\n", - "cytosolic_labels = [0 for protein in cytosolic_sequences]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d0e7318b", - "metadata": { - "id": "d0e7318b" - }, - "outputs": [], - "source": [ - "membrane_sequences = membrane_df[\"Sequence\"].tolist()\n", - "membrane_labels = [1 for protein in membrane_sequences]" - ] - }, - { - "cell_type": "markdown", - "id": "5a4bbda2", - "metadata": { - "id": "5a4bbda2" - }, - "source": [ - "Now we can concatenate these lists together to get the `sequences` and `labels` lists that will form our final training data. 
Don't worry - they'll get shuffled during training!\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7dec7a4a", - "metadata": { - "id": "7dec7a4a", - "outputId": "960e5686-900f-48f0-bf4c-f848b3e38224" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sequences = cytosolic_sequences + membrane_sequences\n", - "labels = cytosolic_labels + membrane_labels\n", - "\n", - "# Quick check to make sure we got it right\n", - "len(sequences) == len(labels)" - ] - }, - { - "cell_type": "markdown", - "id": "bc782dd0", - "metadata": { - "id": "bc782dd0" - }, - "source": [ - "Phew!\n" - ] - }, - { - "cell_type": "markdown", - "id": "e0aac39c", - "metadata": { - "id": "e0aac39c" - }, - "source": [ - "## Splitting the data\n" - ] - }, - { - "cell_type": "markdown", - "id": "a9099e7c", - "metadata": { - "id": "a9099e7c" - }, - "source": [ - "Since the data we're loading isn't prepared for us as a machine learning dataset, we'll have to split the data into train and test sets ourselves! We can use sklearn's function for that:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "366147ad", - "metadata": { - "id": "366147ad" - }, - "outputs": [], - "source": [ - "from sklearn.model_selection import train_test_split\n", - "\n", - "train_sequences, test_sequences, train_labels, test_labels = train_test_split(\n", - " sequences, labels, test_size=0.25, shuffle=True\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "7d29b4ed", - "metadata": { - "id": "7d29b4ed" - }, - "source": [ - "## Tokenizing the data\n" - ] - }, - { - "cell_type": "markdown", - "id": "c02baaf7", - "metadata": { - "id": "c02baaf7" - }, - "source": [ - "All inputs to neural nets must be numerical. The process of converting strings into numerical indices suitable for a neural net is called **tokenization**. For natural language this can be quite complex, as usually the network's vocabulary will not contain every possible word, which means the tokenizer must handle splitting rarer words into pieces, as well as all the complexities of capitalization and unicode characters and so on.\n", - "\n", - "With proteins, however, things are very easy. In protein language models, each amino acid is converted to a single token. Every model on `transformers` comes with an associated `tokenizer` that handles tokenization for it, and protein language models are no different. Let's get our tokenizer!\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ddbe2b2d", - "metadata": { - "colab": { - "referenced_widgets": [ - "cc61f599adc641da8d40eefac0179aa3", - "4e3d35bc47c74852a5256f5b34312030", - "8b7f3b4a47d8437f9e9cc6648cfc8984" - ] - }, - "id": "ddbe2b2d", - "outputId": "88b04123-b40d-40c1-9d2f-d6b39dd1b3fa" - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cc61f599adc641da8d40eefac0179aa3", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Downloading: 0%| | 0.00/40.0 [00:00\n", - " \n", - " \n", - " [1428/1428 04:35, Epoch 3/3]\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
- "      Epoch | Training Loss | Validation Loss | Accuracy\n",
- "      1     | No log        | 0.210921        | 0.936958\n",
- "      2     | 0.232100      | 0.205099        | 0.944050\n",
- "      3     | 0.145700      | 0.200019        | 0.946414\n",
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "***** Running Evaluation *****\n", - " Num examples = 1269\n", - " Batch size = 8\n", - "Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-localization/checkpoint-476\n", - "Configuration saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-476/config.json\n", - "Model weights saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-476/pytorch_model.bin\n", - "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-476/tokenizer_config.json\n", - "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-476/special_tokens_map.json\n", - "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/tokenizer_config.json\n", - "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-localization/special_tokens_map.json\n", - "***** Running Evaluation *****\n", - " Num examples = 1269\n", - " Batch size = 8\n", - "Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-localization/checkpoint-952\n", - "Configuration saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-952/config.json\n", - "Model weights saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-952/pytorch_model.bin\n", - "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-952/tokenizer_config.json\n", - "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-952/special_tokens_map.json\n", - "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/tokenizer_config.json\n", - "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-localization/special_tokens_map.json\n", - "***** Running Evaluation *****\n", - " Num examples = 1269\n", - " Batch size = 8\n", - "Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-localization/checkpoint-1428\n", - "Configuration saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-1428/config.json\n", - "Model weights saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-1428/pytorch_model.bin\n", - "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-1428/tokenizer_config.json\n", - "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-1428/special_tokens_map.json\n", - "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/tokenizer_config.json\n", - "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-localization/special_tokens_map.json\n", - "\n", - "\n", - "Training completed. Do not forget to share your model on huggingface.co/models =)\n", - "\n", - "\n", - "Loading best model from esm2_t12_35M_UR50D-finetuned-localization/checkpoint-1428 (score: 0.946414499605989).\n" - ] - }, - { - "data": { - "text/plain": [ - "TrainOutput(global_step=1428, training_loss=0.1632746127473206, metrics={'train_runtime': 281.8102, 'train_samples_per_second': 40.506, 'train_steps_per_second': 5.067, 'total_flos': 1032423103475172.0, 'train_loss': 0.1632746127473206, 'epoch': 3.0})" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "trainer.train()" - ] - }, - { - "cell_type": "markdown", - "id": "dfec59f4", - "metadata": { - "id": "dfec59f4" - }, - "source": [ - "Nice! After three epochs we have a model accuracy of ~94%. 
Note that we didn't do a lot of work to filter the training data or tune hyperparameters for this experiment, and also that we used one of the smallest ESM-2 models. With a larger starting model and more effort to ensure that the training data categories were cleanly separable, accuracy could almost certainly go a lot higher!\n" - ] - }, - { - "cell_type": "markdown", - "id": "bc2ef458", - "metadata": { - "id": "bc2ef458" - }, - "source": [ - "---\n", - "\n", - "# Token classification\n" - ] - }, - { - "cell_type": "markdown", - "id": "78d701ed", - "metadata": { - "id": "78d701ed" - }, - "source": [ - "Another common language model task is **token classification**. In this task, instead of classifying the whole sequence into a single category, we categorize each token (amino acid, in this case!) into one or more categories. This kind of model could be useful for:\n", - "\n", - "- Predicting secondary structure\n", - "- Predicting buried vs. exposed residues\n", - "- Predicting residues that will receive post-translational modifications\n", - "- Predicting residues involved in binding pockets or active sites\n", - "- Probably several other things, it's been a while since I was a postdoc\n" - ] - }, - { - "cell_type": "markdown", - "id": "20e00afe", - "metadata": { - "id": "20e00afe" - }, - "source": [ - "## Data preparation\n" - ] - }, - { - "cell_type": "markdown", - "id": "f1b9e75c", - "metadata": { - "id": "f1b9e75c" - }, - "source": [ - "In this section, we're going to gather some training data from UniProt. As in the sequence classification example, we aim to create two lists: `sequences` and `labels`. Unlike in that example, however, the `labels` are more than just single integers. Instead, the label for each sample will be **one integer per token in the input**. This should make sense - when we do token classification, different tokens in the input may have different categories!\n", - "\n", - "To demonstrate token classification, we're going to go back to UniProt and get some data on protein secondary structures. As above, this will probably the main section you want to change when adapting this code to your own problems.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bf52cfb8", - "metadata": { - "id": "bf52cfb8" - }, - "outputs": [], - "source": [ - "import requests\n", - "\n", - "query_url = \"https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Cft_strand%2Cft_helix&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B80%20TO%20500%5D%29%29\"" - ] - }, - { - "cell_type": "markdown", - "id": "73c902be", - "metadata": { - "id": "73c902be" - }, - "source": [ - "This time, our UniProt search was `(organism_id:9606) AND (reviewed:true) AND (length:[100 TO 1000])` as it was in the first example, but instead of `Subcellular location [CC]` we take the `Helix` and `Beta strand` columns, as they contain the secondary structure information we want.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "be65f529", - "metadata": { - "id": "be65f529" - }, - "outputs": [], - "source": [ - "uniprot_request = requests.get(query_url)" - ] - }, - { - "cell_type": "markdown", - "id": "3f683dd7", - "metadata": { - "id": "3f683dd7" - }, - "source": [ - "To get this data into Pandas, we use a `BytesIO` object, which Pandas will treat like a file. 
If you downloaded the data as a file you can skip this bit and just pass the filepath directly to `read_csv`.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f49439ab", - "metadata": { - "id": "f49439ab", - "outputId": "1de4fc57-0ce6-4a70-df61-53e9b9980948", - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/html": [ - "

\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
EntrySequenceBeta strandHelix
0A0A0K2S4Q6MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...NaNNaN
1A0A5B9DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...STRAND 9..14; /evidence=\"ECO:0007829|PDB:4UDT\"...HELIX 2..4; /evidence=\"ECO:0007829|PDB:4UDT\"; ...
2A0AVI4MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA...NaNNaN
3A0JLT2MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...STRAND 79..81; /evidence=\"ECO:0007829|PDB:7EMF\"HELIX 83..86; /evidence=\"ECO:0007829|PDB:7EMF\"...
4A0M8Q6GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...NaNNaN
...............
11977Q9NZ38MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG...NaNNaN
11978Q9UFV3MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV...NaNNaN
11979Q9Y6C7MAHHSLNTFYIWHNNVLHTHLVFFLPHLLNQPFSRGSFLIWLLLCW...NaNNaN
11980X6R8D5MGRKEHESPSQPHMCGWEDSQKPSVPSHGPKTPSCKGVKAPHSSRP...NaNNaN
11981X6R8R1MGVVLSPHPAPSRREPLAPLAPGTRPGWSPAVSGSSRSALRPSTAG...NaNNaN
\n", - "

11982 rows × 4 columns

\n", - "
" - ], - "text/plain": [ - " Entry Sequence \\\n", - "0 A0A0K2S4Q6 MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE... \n", - "1 A0A5B9 DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV... \n", - "2 A0AVI4 MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA... \n", - "3 A0JLT2 MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP... \n", - "4 A0M8Q6 GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG... \n", - "... ... ... \n", - "11977 Q9NZ38 MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG... \n", - "11978 Q9UFV3 MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV... \n", - "11979 Q9Y6C7 MAHHSLNTFYIWHNNVLHTHLVFFLPHLLNQPFSRGSFLIWLLLCW... \n", - "11980 X6R8D5 MGRKEHESPSQPHMCGWEDSQKPSVPSHGPKTPSCKGVKAPHSSRP... \n", - "11981 X6R8R1 MGVVLSPHPAPSRREPLAPLAPGTRPGWSPAVSGSSRSALRPSTAG... \n", - "\n", - " Beta strand \\\n", - "0 NaN \n", - "1 STRAND 9..14; /evidence=\"ECO:0007829|PDB:4UDT\"... \n", - "2 NaN \n", - "3 STRAND 79..81; /evidence=\"ECO:0007829|PDB:7EMF\" \n", - "4 NaN \n", - "... ... \n", - "11977 NaN \n", - "11978 NaN \n", - "11979 NaN \n", - "11980 NaN \n", - "11981 NaN \n", - "\n", - " Helix \n", - "0 NaN \n", - "1 HELIX 2..4; /evidence=\"ECO:0007829|PDB:4UDT\"; ... \n", - "2 NaN \n", - "3 HELIX 83..86; /evidence=\"ECO:0007829|PDB:7EMF\"... \n", - "4 NaN \n", - "... ... \n", - "11977 NaN \n", - "11978 NaN \n", - "11979 NaN \n", - "11980 NaN \n", - "11981 NaN \n", - "\n", - "[11982 rows x 4 columns]" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from io import BytesIO\n", - "import pandas\n", - "\n", - "bio = BytesIO(uniprot_request.content)\n", - "\n", - "df = pandas.read_csv(bio, compression=\"gzip\", sep=\"\\t\")\n", - "df" - ] - }, - { - "cell_type": "markdown", - "id": "736010f0", - "metadata": { - "id": "736010f0" - }, - "source": [ - "Since not all proteins have this structural information, we discard proteins that have no annotated beta strands or alpha helices.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "39ce9a5c", - "metadata": { - "id": "39ce9a5c", - "outputId": "4334cbe5-a43e-4c6a-ff32-2cc2af9ae864" - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
EntrySequenceBeta strandHelix
1A0A5B9DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...STRAND 9..14; /evidence=\"ECO:0007829|PDB:4UDT\"...HELIX 2..4; /evidence=\"ECO:0007829|PDB:4UDT\"; ...
3A0JLT2MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...STRAND 79..81; /evidence=\"ECO:0007829|PDB:7EMF\"HELIX 83..86; /evidence=\"ECO:0007829|PDB:7EMF\"...
14A1L3X0MAFSDLTSRTVHLYDNWIKDADPRVEDWLLMSSPLPQTILLGFYVY...STRAND 97..99; /evidence=\"ECO:0007829|PDB:6Y7F\"HELIX 17..20; /evidence=\"ECO:0007829|PDB:6Y7F\"...
16A1Z1Q3MYPSNKKKKVWREEKERLLKMTLEERRKEYLRDYIPLNSILSWKEE...STRAND 71..77; /evidence=\"ECO:0007829|PDB:4IQY...HELIX 11..19; /evidence=\"ECO:0007829|PDB:4IQY\"...
20A2RUC4MAGQHLPVPRLEGVSREQFMQHLYPQRKPLVLEGIDLGPCTSKWTV...STRAND 10..13; /evidence=\"ECO:0007829|PDB:3AL5...HELIX 16..22; /evidence=\"ECO:0007829|PDB:3AL5\"...
...............
11551Q96I45MVNLGLSRVDDAVAAKHPGLGEYAACQSHAFMKGVFTFVTGTGMAF...STRAND 3..5; /evidence=\"ECO:0007829|PDB:2LOR\";...HELIX 6..16; /evidence=\"ECO:0007829|PDB:2LOR\";...
11614Q9H0W7MPTNCAAAGCATTYNKHINISFHRFPLDPKRRKEWVRLVRRKNFVP...STRAND 7..9; /evidence=\"ECO:0007829|PDB:2D8R\";...HELIX 29..38; /evidence=\"ECO:0007829|PDB:2D8R\"
11659Q9P1F3MNVDHEVNLLVEEIHRLGSKNADGKLSVKFGVLFRDDKCANLFEAL...STRAND 24..29; /evidence=\"ECO:0007829|PDB:2L2O...HELIX 3..17; /evidence=\"ECO:0007829|PDB:2L2O\";...
11661Q9P298MSANRRWWVPPDDEDCVSEKLLRKTRESPLVPIGLGGCLVVAAYRI...STRAND 11..14; /evidence=\"ECO:0007829|PDB:2LON...HELIX 18..24; /evidence=\"ECO:0007829|PDB:2LON\"...
11668Q9UIY3MSASVKESLQLQLLEMEMLFSMFPNQGEVKLEDVNALTNIKRYLEG...STRAND 28..32; /evidence=\"ECO:0007829|PDB:2DAW...HELIX 5..22; /evidence=\"ECO:0007829|PDB:2DAW\";...
\n", - "

3911 rows × 4 columns

\n", - "
" - ], - "text/plain": [ - " Entry Sequence \\\n", - "1 A0A5B9 DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV... \n", - "3 A0JLT2 MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP... \n", - "14 A1L3X0 MAFSDLTSRTVHLYDNWIKDADPRVEDWLLMSSPLPQTILLGFYVY... \n", - "16 A1Z1Q3 MYPSNKKKKVWREEKERLLKMTLEERRKEYLRDYIPLNSILSWKEE... \n", - "20 A2RUC4 MAGQHLPVPRLEGVSREQFMQHLYPQRKPLVLEGIDLGPCTSKWTV... \n", - "... ... ... \n", - "11551 Q96I45 MVNLGLSRVDDAVAAKHPGLGEYAACQSHAFMKGVFTFVTGTGMAF... \n", - "11614 Q9H0W7 MPTNCAAAGCATTYNKHINISFHRFPLDPKRRKEWVRLVRRKNFVP... \n", - "11659 Q9P1F3 MNVDHEVNLLVEEIHRLGSKNADGKLSVKFGVLFRDDKCANLFEAL... \n", - "11661 Q9P298 MSANRRWWVPPDDEDCVSEKLLRKTRESPLVPIGLGGCLVVAAYRI... \n", - "11668 Q9UIY3 MSASVKESLQLQLLEMEMLFSMFPNQGEVKLEDVNALTNIKRYLEG... \n", - "\n", - " Beta strand \\\n", - "1 STRAND 9..14; /evidence=\"ECO:0007829|PDB:4UDT\"... \n", - "3 STRAND 79..81; /evidence=\"ECO:0007829|PDB:7EMF\" \n", - "14 STRAND 97..99; /evidence=\"ECO:0007829|PDB:6Y7F\" \n", - "16 STRAND 71..77; /evidence=\"ECO:0007829|PDB:4IQY... \n", - "20 STRAND 10..13; /evidence=\"ECO:0007829|PDB:3AL5... \n", - "... ... \n", - "11551 STRAND 3..5; /evidence=\"ECO:0007829|PDB:2LOR\";... \n", - "11614 STRAND 7..9; /evidence=\"ECO:0007829|PDB:2D8R\";... \n", - "11659 STRAND 24..29; /evidence=\"ECO:0007829|PDB:2L2O... \n", - "11661 STRAND 11..14; /evidence=\"ECO:0007829|PDB:2LON... \n", - "11668 STRAND 28..32; /evidence=\"ECO:0007829|PDB:2DAW... \n", - "\n", - " Helix \n", - "1 HELIX 2..4; /evidence=\"ECO:0007829|PDB:4UDT\"; ... \n", - "3 HELIX 83..86; /evidence=\"ECO:0007829|PDB:7EMF\"... \n", - "14 HELIX 17..20; /evidence=\"ECO:0007829|PDB:6Y7F\"... \n", - "16 HELIX 11..19; /evidence=\"ECO:0007829|PDB:4IQY\"... \n", - "20 HELIX 16..22; /evidence=\"ECO:0007829|PDB:3AL5\"... \n", - "... ... \n", - "11551 HELIX 6..16; /evidence=\"ECO:0007829|PDB:2LOR\";... \n", - "11614 HELIX 29..38; /evidence=\"ECO:0007829|PDB:2D8R\" \n", - "11659 HELIX 3..17; /evidence=\"ECO:0007829|PDB:2L2O\";... \n", - "11661 HELIX 18..24; /evidence=\"ECO:0007829|PDB:2LON\"... \n", - "11668 HELIX 5..22; /evidence=\"ECO:0007829|PDB:2DAW\";... \n", - "\n", - "[3911 rows x 4 columns]" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "no_structure_rows = df[\"Beta strand\"].isna() & df[\"Helix\"].isna()\n", - "df = df[~no_structure_rows]\n", - "df" - ] - }, - { - "cell_type": "markdown", - "id": "f43e372c", - "metadata": { - "id": "f43e372c" - }, - "source": [ - "Well, this works, but that data still isn't in a clean format that we can use to build our labels. Let's take a look at one sample to see what exactly we're dealing with:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "73e99d1b", - "metadata": { - "id": "73e99d1b", - "outputId": "05cbcf3e-9d0a-49fd-9c63-66a1d9fe64d9" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "'HELIX 2..4; /evidence=\"ECO:0007829|PDB:4UDT\"; HELIX 17..23; /evidence=\"ECO:0007829|PDB:4UDT\"; HELIX 83..86; /evidence=\"ECO:0007829|PDB:4UDT\"'" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.iloc[0][\"Helix\"]" - ] - }, - { - "cell_type": "markdown", - "id": "6cd5160a", - "metadata": { - "id": "6cd5160a" - }, - "source": [ - "We'll need to use a [regex](https://docs.python.org/3/howto/regex.html) to pull out each segment that's marked as being a STRAND or HELIX. 
What we're asking for is a list of everywhere we see the word STRAND or HELIX followed by two numbers separated by two dots. In each case where this pattern is found, we tell the regex to extract the two numbers as a tuple for us.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7540949e", - "metadata": { - "id": "7540949e", - "outputId": "aeb08e47-a28a-4b22-fd39-95a21d987a16" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[('2', '4'), ('17', '23'), ('83', '86')]" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import re\n", - "\n", - "strand_re = r\"STRAND\\s(\\d+)\\.\\.(\\d+)\\;\"\n", - "helix_re = r\"HELIX\\s(\\d+)\\.\\.(\\d+)\\;\"\n", - "\n", - "re.findall(helix_re, df.iloc[0][\"Helix\"])" - ] - }, - { - "cell_type": "markdown", - "id": "4457b1a0", - "metadata": { - "id": "4457b1a0" - }, - "source": [ - "Looks good! We can use this to build our training data. Recall that the **labels** need to be a list or array of integers that's the same length as the input sequence. We're going to use 0 to indicate residues without any annotated structure, 1 for residues in an alpha helix, and 2 for residues in a beta strand. To build that, we'll start with an array of all 0s, and then fill in values based on the positions that our regex pulls out of the UniProt results.\n", - "\n", - "We'll use NumPy arrays rather than lists here, since these allow [slice assignment](https://numpy.org/doc/stable/user/basics.indexing.html#assigning-values-to-indexed-arrays), which will be a lot simpler than editing a list of integers. Note also that UniProt annotates residues starting from 1 (unlike Python, which starts from 0), and region annotations are inclusive (so 1..3 means residues 1, 2 and 3). 
To turn these into Python slices, we subtract 1 from the start of each annotation, but not the end.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a4c97179", - "metadata": { - "id": "a4c97179" - }, - "outputs": [], - "source": [ - "import numpy as np\n", - "\n", - "\n", - "def build_labels(sequence, strands, helices):\n", - " # Start with all 0s\n", - " labels = np.zeros(len(sequence), dtype=np.int64)\n", - "\n", - " if isinstance(helices, float): # Indicates missing (NaN)\n", - " found_helices = []\n", - " else:\n", - " found_helices = re.findall(helix_re, helices)\n", - " for helix_start, helix_end in found_helices:\n", - " helix_start = int(helix_start) - 1\n", - " helix_end = int(helix_end)\n", - " assert helix_end <= len(sequence)\n", - " labels[helix_start:helix_end] = 1 # Helix category\n", - "\n", - " if isinstance(strands, float): # Indicates missing (NaN)\n", - " found_strands = []\n", - " else:\n", - " found_strands = re.findall(strand_re, strands)\n", - " for strand_start, strand_end in found_strands:\n", - " strand_start = int(strand_start) - 1\n", - " strand_end = int(strand_end)\n", - " assert strand_end <= len(sequence)\n", - " labels[strand_start:strand_end] = 2 # Strand category\n", - " return labels" - ] - }, - { - "cell_type": "markdown", - "id": "5ad7e7fd", - "metadata": { - "id": "5ad7e7fd" - }, - "source": [ - "Now we've defined a helper function, let's build our lists of sequences and labels:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "313811fe", - "metadata": { - "id": "313811fe" - }, - "outputs": [], - "source": [ - "sequences = []\n", - "labels = []\n", - "\n", - "for row_idx, row in df.iterrows():\n", - " row_labels = build_labels(row[\"Sequence\"], row[\"Beta strand\"], row[\"Helix\"])\n", - " sequences.append(row[\"Sequence\"])\n", - " labels.append(row_labels)" - ] - }, - { - "cell_type": "markdown", - "id": "8e8b3ba8", - "metadata": { - "id": "8e8b3ba8" - }, - "source": [ - "## Creating our dataset\n" - ] - }, - { - "cell_type": "markdown", - "id": "e619d9ae", - "metadata": { - "id": "e619d9ae" - }, - "source": [ - "Nice! 
Now we'll split and tokenize the data, and then create datasets - I'll go through this quite quickly here, since it's identical to how we did it in the sequence classification example above.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3c208c30", - "metadata": { - "id": "3c208c30" - }, - "outputs": [], - "source": [ - "from sklearn.model_selection import train_test_split\n", - "\n", - "train_sequences, test_sequences, train_labels, test_labels = train_test_split(\n", - " sequences, labels, test_size=0.25, shuffle=True\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2182fae2", - "metadata": { - "id": "2182fae2", - "outputId": "af9ca362-41c5-423e-9ab2-467546769285" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "loading file vocab.txt from cache at /home/matt/.cache/huggingface/hub/models--facebook--esm2_t12_35M_UR50D/snapshots/dbb5b2b74bf5bd9cd0ab5c2b95ef3994f69879a3/vocab.txt\n", - "loading file added_tokens.json from cache at None\n", - "loading file special_tokens_map.json from cache at /home/matt/.cache/huggingface/hub/models--facebook--esm2_t12_35M_UR50D/snapshots/dbb5b2b74bf5bd9cd0ab5c2b95ef3994f69879a3/special_tokens_map.json\n", - "loading file tokenizer_config.json from cache at /home/matt/.cache/huggingface/hub/models--facebook--esm2_t12_35M_UR50D/snapshots/dbb5b2b74bf5bd9cd0ab5c2b95ef3994f69879a3/tokenizer_config.json\n" - ] - } - ], - "source": [ - "from transformers import AutoTokenizer\n", - "\n", - "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n", - "\n", - "train_tokenized = tokenizer(train_sequences)\n", - "test_tokenized = tokenizer(test_sequences)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3939f13a", - "metadata": { - "id": "3939f13a" - }, - "outputs": [], - "source": [ - "from datasets import Dataset\n", - "\n", - "train_dataset = Dataset.from_dict(train_tokenized)\n", - "test_dataset = Dataset.from_dict(test_tokenized)\n", - "\n", - "train_dataset = train_dataset.add_column(\"labels\", train_labels)\n", - "test_dataset = test_dataset.add_column(\"labels\", test_labels)" - ] - }, - { - "cell_type": "markdown", - "id": "4766fe4b", - "metadata": { - "id": "4766fe4b" - }, - "source": [ - "## Model loading\n" - ] - }, - { - "cell_type": "markdown", - "id": "de8419b5", - "metadata": { - "id": "de8419b5" - }, - "source": [ - "The key difference here with the above example is that we use `AutoModelForTokenClassification` instead of `AutoModelForSequenceClassification`. 
We will also need a `data_collator` this time, as we're in the slightly more complex case where both inputs and labels must be padded in each batch.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4b26b828", - "metadata": { - "id": "4b26b828", - "outputId": "b0ab3f12-ba38-4bbb-f456-5c01c80711fb" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "loading configuration file config.json from cache at /home/matt/.cache/huggingface/hub/models--facebook--esm2_t12_35M_UR50D/snapshots/dbb5b2b74bf5bd9cd0ab5c2b95ef3994f69879a3/config.json\n", - "Model config EsmConfig {\n", - " \"_name_or_path\": \"facebook/esm2_t12_35M_UR50D\",\n", - " \"architectures\": [\n", - " \"EsmForMaskedLM\"\n", - " ],\n", - " \"attention_probs_dropout_prob\": 0.0,\n", - " \"classifier_dropout\": null,\n", - " \"emb_layer_norm_before\": false,\n", - " \"esmfold_config\": null,\n", - " \"hidden_act\": \"gelu\",\n", - " \"hidden_dropout_prob\": 0.0,\n", - " \"hidden_size\": 480,\n", - " \"id2label\": {\n", - " \"0\": \"LABEL_0\",\n", - " \"1\": \"LABEL_1\",\n", - " \"2\": \"LABEL_2\"\n", - " },\n", - " \"initializer_range\": 0.02,\n", - " \"intermediate_size\": 1920,\n", - " \"is_folding_model\": false,\n", - " \"label2id\": {\n", - " \"LABEL_0\": 0,\n", - " \"LABEL_1\": 1,\n", - " \"LABEL_2\": 2\n", - " },\n", - " \"layer_norm_eps\": 1e-05,\n", - " \"mask_token_id\": 32,\n", - " \"max_position_embeddings\": 1026,\n", - " \"model_type\": \"esm\",\n", - " \"num_attention_heads\": 20,\n", - " \"num_hidden_layers\": 12,\n", - " \"pad_token_id\": 1,\n", - " \"position_embedding_type\": \"rotary\",\n", - " \"token_dropout\": true,\n", - " \"torch_dtype\": \"float32\",\n", - " \"transformers_version\": \"4.25.0.dev0\",\n", - " \"use_cache\": true,\n", - " \"vocab_list\": null,\n", - " \"vocab_size\": 33\n", - "}\n", - "\n", - "loading weights file pytorch_model.bin from cache at /home/matt/.cache/huggingface/hub/models--facebook--esm2_t12_35M_UR50D/snapshots/dbb5b2b74bf5bd9cd0ab5c2b95ef3994f69879a3/pytorch_model.bin\n", - "Some weights of the model checkpoint at facebook/esm2_t12_35M_UR50D were not used when initializing EsmForTokenClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.decoder.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias']\n", - "- This IS expected if you are initializing EsmForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", - "- This IS NOT expected if you are initializing EsmForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", - "Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.weight', 'classifier.bias']\n", - "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" - ] - } - ], - "source": [ - "from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer\n", - "\n", - "num_labels = 3\n", - "model = AutoModelForTokenClassification.from_pretrained(\n", - " model_checkpoint, num_labels=num_labels\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "eec0005a", - "metadata": { - "id": "eec0005a" - }, - "outputs": [], - "source": [ - "from transformers import DataCollatorForTokenClassification\n", - "\n", - "data_collator = DataCollatorForTokenClassification(tokenizer)" - ] - }, - { - "cell_type": "markdown", - "id": "bd3c7305", - "metadata": { - "id": "bd3c7305" - }, - "source": [ - "Now we set up our `TrainingArguments` as before.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e7724323", - "metadata": { - "id": "e7724323", - "outputId": "dfd30ff9-4b83-43c1-c9cb-35be6f9ffcaa" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "PyTorch: setting up devices\n", - "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. 
You should start updating your code and make this info disappear :-).\n" - ] - } - ], - "source": [ - "model_name = model_checkpoint.split(\"/\")[-1]\n", - "batch_size = 8\n", - "\n", - "args = TrainingArguments(\n", - " f\"{model_name}-finetuned-secondary-structure\",\n", - " evaluation_strategy=\"epoch\",\n", - " save_strategy=\"epoch\",\n", - " learning_rate=1e-4,\n", - " per_device_train_batch_size=batch_size,\n", - " per_device_eval_batch_size=batch_size,\n", - " num_train_epochs=3,\n", - " weight_decay=0.001,\n", - " load_best_model_at_end=True,\n", - " metric_for_best_model=\"accuracy\",\n", - " push_to_hub=True,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "fb5fba9a", - "metadata": { - "id": "fb5fba9a" - }, - "source": [ - "Our `compute_metrics` function is a bit more complex than in the sequence classification task, as we need to ignore padding tokens (those where the label is `-100`).\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "736886a0", - "metadata": { - "id": "736886a0" - }, - "outputs": [], - "source": [ - "from evaluate import load\n", - "import numpy as np\n", - "\n", - "metric = load(\"accuracy\")\n", - "\n", - "\n", - "def compute_metrics(eval_pred):\n", - " predictions, labels = eval_pred\n", - " labels = labels.reshape((-1,))\n", - " predictions = np.argmax(predictions, axis=2)\n", - " predictions = predictions.reshape((-1,))\n", - " predictions = predictions[labels != -100]\n", - " labels = labels[labels != -100]\n", - " return metric.compute(predictions=predictions, references=labels)" - ] - }, - { - "cell_type": "markdown", - "id": "37491af5", - "metadata": { - "id": "37491af5" - }, - "source": [ - "And now we're ready to train our model!\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4c97836c", - "metadata": { - "id": "4c97836c", - "outputId": "2a8b355f-b9ab-4c36-ca4f-8cd303c62c6f" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/matt/PycharmProjects/notebooks/examples/esm2_t12_35M_UR50D-finetuned-secondary-structure is already a clone of https://huggingface.co/Rocketknight1/esm2_t12_35M_UR50D-finetuned-secondary-structure. Make sure you pull the latest changes with `repo.git_pull()`.\n", - "/home/matt/PycharmProjects/transformers/src/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", - " warnings.warn(\n", - "***** Running training *****\n", - " Num examples = 2933\n", - " Num Epochs = 3\n", - " Instantaneous batch size per device = 8\n", - " Total train batch size (w. parallel, distributed & accumulation) = 8\n", - " Gradient Accumulation steps = 1\n", - " Total optimization steps = 1101\n", - " Number of trainable parameters = 33763203\n", - "Automatic Weights & Biases logging enabled, to disable set os.environ[\"WANDB_DISABLED\"] = \"true\"\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [1101/1101 03:52, Epoch 3/3]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
EpochTraining LossValidation LossAccuracy
1No log0.4654080.809475
20.4962000.4439260.818526
30.3711000.4491090.821522

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "***** Running Evaluation *****\n", - " Num examples = 978\n", - " Batch size = 8\n", - "Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-367\n", - "Configuration saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-367/config.json\n", - "Model weights saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-367/pytorch_model.bin\n", - "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-367/tokenizer_config.json\n", - "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-367/special_tokens_map.json\n", - "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/tokenizer_config.json\n", - "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/special_tokens_map.json\n", - "***** Running Evaluation *****\n", - " Num examples = 978\n", - " Batch size = 8\n", - "Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-734\n", - "Configuration saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-734/config.json\n", - "Model weights saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-734/pytorch_model.bin\n", - "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-734/tokenizer_config.json\n", - "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-734/special_tokens_map.json\n", - "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/tokenizer_config.json\n", - "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/special_tokens_map.json\n", - "***** Running Evaluation *****\n", - " Num examples = 978\n", - " Batch size = 8\n", - "Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-1101\n", - "Configuration saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-1101/config.json\n", - "Model weights saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-1101/pytorch_model.bin\n", - "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-1101/tokenizer_config.json\n", - "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-1101/special_tokens_map.json\n", - "tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/tokenizer_config.json\n", - "Special tokens file saved in esm2_t12_35M_UR50D-finetuned-secondary-structure/special_tokens_map.json\n", - "\n", - "\n", - "Training completed. 
Do not forget to share your model on huggingface.co/models =)\n", - "\n", - "\n", - "Loading best model from esm2_t12_35M_UR50D-finetuned-secondary-structure/checkpoint-1101 (score: 0.8215224822508546).\n" - ] - }, - { - "data": { - "text/plain": [ - "TrainOutput(global_step=1101, training_loss=0.42545173083728927, metrics={'train_runtime': 232.9156, 'train_samples_per_second': 37.778, 'train_steps_per_second': 4.727, 'total_flos': 794586720601188.0, 'train_loss': 0.42545173083728927, 'epoch': 3.0})" - ] - }, - "execution_count": 38, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "trainer = Trainer(\n", - " model,\n", - " args,\n", - " train_dataset=train_dataset,\n", - " eval_dataset=test_dataset,\n", - " tokenizer=tokenizer,\n", - " compute_metrics=compute_metrics,\n", - " data_collator=data_collator,\n", - ")\n", - "\n", - "trainer.train()" - ] - }, - { - "cell_type": "markdown", - "id": "f503fc00", - "metadata": { - "id": "f503fc00" - }, - "source": [ - "This definitely seems harder than the first task, but we still attain a very respectable accuracy. Remember that to keep this demo lightweight, we used one of the smallest ESM models, focused on human proteins only and didn't put a lot of work into making sure we only included completely-annotated proteins in our training set. With a bigger model and a cleaner, broader training set, accuracy on this task could definitely go a lot higher!\n" - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.8" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 1b6ad764b73edd25b1a06f550531a96787f62fa7 Mon Sep 17 00:00:00 2001 From: ElektrikSpark Date: Sat, 19 Oct 2024 12:05:03 +0000 Subject: [PATCH 2/5] feat: netmhcpan works --- apps/fastapi/.env.example | 14 +- apps/fastapi/src/app/core/config.py | 10 +- apps/fastapi/src/app/core/utils.py | 17 +- apps/fastapi/src/app/services/evaluation.py | 4 +- apps/fastapi/src/app/services/inference.py | 122 +- apps/fastapi/src/app/services/postprocess.py | 84 +- .../tables/mhc-peptide-dialog-cell.tsx | 30 +- apps/tools-fastapi/.dockerignore | 37 + apps/tools-fastapi/.env.example | 1 + apps/tools-fastapi/.gitignore | 179 ++ apps/tools-fastapi/.python-version | 1 + apps/tools-fastapi/.vscode/launch.json | 13 + apps/tools-fastapi/.vscode/settings.json | 5 + apps/tools-fastapi/netMHCIIpan-4.3.readme | 142 ++ apps/tools-fastapi/netMHCpan-4.1.readme | 186 ++ apps/tools-fastapi/package.json | 8 + apps/tools-fastapi/pyproject.toml | 29 + apps/tools-fastapi/requirements.txt | 311 +++ apps/tools-fastapi/src/app/__init__.py | 1 + apps/tools-fastapi/src/app/api/__init__.py | 1 + .../src/app/api/api_v1/__init__.py | 0 apps/tools-fastapi/src/app/api/api_v1/api.py | 13 + .../src/app/api/api_v1/endpoints/__init__.py | 0 .../app/api/api_v1/endpoints/prediction.py | 46 + apps/tools-fastapi/src/app/core/__init__.py | 0 apps/tools-fastapi/src/app/core/config.py | 49 + apps/tools-fastapi/src/app/core/utils.py | 3 + apps/tools-fastapi/src/app/main.py | 65 + .../src/app/services/__init__.py | 0 .../src/app/services/inference.py | 114 ++ apps/tools-fastapi/uv.lock | 1663 +++++++++++++++++ 31 files changed, 3006 insertions(+), 142 
deletions(-) create mode 100644 apps/tools-fastapi/.dockerignore create mode 100644 apps/tools-fastapi/.env.example create mode 100644 apps/tools-fastapi/.gitignore create mode 100644 apps/tools-fastapi/.python-version create mode 100644 apps/tools-fastapi/.vscode/launch.json create mode 100644 apps/tools-fastapi/.vscode/settings.json create mode 100644 apps/tools-fastapi/netMHCIIpan-4.3.readme create mode 100644 apps/tools-fastapi/netMHCpan-4.1.readme create mode 100644 apps/tools-fastapi/package.json create mode 100644 apps/tools-fastapi/pyproject.toml create mode 100644 apps/tools-fastapi/requirements.txt create mode 100644 apps/tools-fastapi/src/app/__init__.py create mode 100644 apps/tools-fastapi/src/app/api/__init__.py create mode 100644 apps/tools-fastapi/src/app/api/api_v1/__init__.py create mode 100644 apps/tools-fastapi/src/app/api/api_v1/api.py create mode 100644 apps/tools-fastapi/src/app/api/api_v1/endpoints/__init__.py create mode 100644 apps/tools-fastapi/src/app/api/api_v1/endpoints/prediction.py create mode 100644 apps/tools-fastapi/src/app/core/__init__.py create mode 100644 apps/tools-fastapi/src/app/core/config.py create mode 100644 apps/tools-fastapi/src/app/core/utils.py create mode 100644 apps/tools-fastapi/src/app/main.py create mode 100644 apps/tools-fastapi/src/app/services/__init__.py create mode 100644 apps/tools-fastapi/src/app/services/inference.py create mode 100644 apps/tools-fastapi/uv.lock diff --git a/apps/fastapi/.env.example b/apps/fastapi/.env.example index 9b6e0a7..9446a27 100644 --- a/apps/fastapi/.env.example +++ b/apps/fastapi/.env.example @@ -1,3 +1,10 @@ +BACKEND_CORS_ORIGINS="http://localhost:3000" + +# AWS +S3_BUCKET_NAME="" +SAGEMAKER_ENDPOINT_NAME="" +EC2_TOOLS_API_URL="" + # Supabase SUPABASE_URL="" SUPABASE_KEY="" @@ -5,13 +12,6 @@ JWT_SECRET="" SUPERUSER_EMAIL="zhouge1831@gmail.com" SUPERUSER_PASSWORD="Zz030327#" -# AWS -REGION="us-east-1" -PROJECT_NAME="" -RAW_BUCKET="" -ARTIFACTS_BUCKET="" -OUTPUT_BUCKET="" - # HuggingFace HUGGINGFACE_ACCESS_TOKEN="" diff --git a/apps/fastapi/src/app/core/config.py b/apps/fastapi/src/app/core/config.py index b0c6f55..1889c46 100644 --- a/apps/fastapi/src/app/core/config.py +++ b/apps/fastapi/src/app/core/config.py @@ -23,7 +23,7 @@ class Settings(BaseSettings): ENV: str = Field(default="", env="ENV") - REGION: str = Field(default="us-east-1", env="REGION") + AWS_REGION: str = Field(default="us-east-1", env="AWS_REGION") BACKEND_CORS_ORIGINS: Union[List[AnyHttpUrl], List[str]] = Field( default=["http://localhost:3000"], env="BACKEND_CORS_ORIGINS" ) @@ -48,6 +48,10 @@ def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str default="huggingface-pytorch-inference-2024-10-16-20-16-41-824", env="SAGEMAKER_ENDPOINT_NAME", ) + EC2_TOOLS_API_URL: str = Field( + ..., + env="EC2_TOOLS_API_URL", + ) # Optional HUGGINGFACE_ACCESS_TOKEN: str = Field(None, env="HUGGINGFACE_ACCESS_TOKEN") @@ -55,8 +59,8 @@ def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str # Project details API_VERSION: str = "/api/v1" - PROJECT_NAME: str = "FastAPI App" - PROJECT_DESCRIPTION: str = "A simple FastAPI app" + PROJECT_NAME: str = "B-cell and T-cell Epitope Prediction FastAPI" + PROJECT_DESCRIPTION: str = "B-cell and T-cell Epitope Prediction" # Pydantic configuration to load environment variables from .env model_config = SettingsConfigDict(env_file=".env") diff --git a/apps/fastapi/src/app/core/utils.py b/apps/fastapi/src/app/core/utils.py index 988f768..4713a62 100644 --- 
a/apps/fastapi/src/app/core/utils.py +++ b/apps/fastapi/src/app/core/utils.py @@ -62,7 +62,7 @@ def read_s3_csv( # CRUD Sagemaker Endpoints def get_endpoints(endpoint_name_filter, sagemaker_client=None): if sagemaker_client is None: - sagemaker_client = boto3.client("sagemaker", region_name=settings.REGION) + sagemaker_client = boto3.client("sagemaker", region_name=settings.AWS_REGION) # Retrieve all endpoints for filtered name response = sagemaker_client.list_endpoints( SortBy="Name", NameContains=endpoint_name_filter, MaxResults=100 @@ -87,7 +87,7 @@ def get_endpoints(endpoint_name_filter, sagemaker_client=None): def get_endpoint(endpoint_name_filter, sagemaker_client=None): if sagemaker_client is None: - sagemaker_client = boto3.client("sagemaker", region_name=settings.REGION) + sagemaker_client = boto3.client("sagemaker", region_name=settings.AWS_REGION) endpoints = get_endpoints(endpoint_name_filter, sagemaker_client=sagemaker_client) if len(endpoints) == 0: return None @@ -230,11 +230,16 @@ async def upload_csv_to_s3(results: List[T], s3_key: str): csv_content = output.getvalue() - # Upload the CSV to S3 - async with aioboto3.client("s3", region_name=settings.AWS_REGION) as s3_client: + # Create an aioboto3 session + session = aioboto3.Session() + + # Use the session to create an S3 client with async context manager + async with session.client("s3", region_name=settings.AWS_REGION) as s3_client: try: await s3_client.put_object( Bucket=settings.S3_BUCKET_NAME, Key=s3_key, Body=csv_content ) - except Exception: - raise HTTPException(status_code=500, detail="Failed to upload CSV to S3") + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to upload CSV to S3: {str(e)}" + ) diff --git a/apps/fastapi/src/app/services/evaluation.py b/apps/fastapi/src/app/services/evaluation.py index cce4360..917c1af 100644 --- a/apps/fastapi/src/app/services/evaluation.py +++ b/apps/fastapi/src/app/services/evaluation.py @@ -10,7 +10,9 @@ def get_predictions(requests, endpoint_name, model_name, model_type="mme"): """ Pass preprocessed requests to the specific endpoint (mme or single) """ - sagemaker_runtime = boto3.client("sagemaker-runtime", region_name=settings.REGION) + sagemaker_runtime = boto3.client( + "sagemaker-runtime", region_name=settings.AWS_REGION + ) responses = [] for request in requests: if model_type == "mme": diff --git a/apps/fastapi/src/app/services/inference.py b/apps/fastapi/src/app/services/inference.py index 78579cd..c092671 100644 --- a/apps/fastapi/src/app/services/inference.py +++ b/apps/fastapi/src/app/services/inference.py @@ -1,4 +1,3 @@ -import asyncio import json import logging from typing import Any, Dict, List @@ -11,15 +10,17 @@ logger = logging.getLogger(__name__) -IEDB_API_URL_CLASSI = "http://tools-cluster-interface.iedb.org/tools_api/mhci/" -IEDB_API_URL_CLASSII = "http://tools-cluster-interface.iedb.org/tools_api/mhcii/" +CLASSI_URL = f"{settings.EC2_TOOLS_API_URL}/api/v1/prediction/netmhcpan/" +CLASSII_URL = f"{settings.EC2_TOOLS_API_URL}/api/v1/prediction/netmhciipan/" def get_sagemaker_predictions(requests, endpoint_name, model_name, model_type="mme"): """ Pass preprocessed requests to the specific endpoint (mme or single) """ - sagemaker_runtime = boto3.client("sagemaker-runtime", region_name=settings.REGION) + sagemaker_runtime = boto3.client( + "sagemaker-runtime", region_name=settings.AWS_REGION + ) responses = [] for request in requests: if model_type == "mme": @@ -41,96 +42,39 @@ def get_sagemaker_predictions(requests, 
endpoint_name, model_name, model_type="m return responses +timeout = httpx.Timeout(10.0, read=60.0) # 10 seconds connect, 60 seconds read timeout + + async def run_netmhci_binding_affinity_classI( - peptides: List[str], alleles: List[str], method: str = "netmhcpan-4.1" + peptides: List[str], alleles: List[str] ) -> List[Dict[str, Any]]: """ - Uses IEDB API to generate binding affinity for each peptide and HLA interaction. + Calls the NetMHCpan API with peptides and alleles to get binding affinity results. Args: - peptides (list): A list of peptide sequences. - alleles (list): A list of HLA alleles for which to make predictions. - method (str): Prediction method to use. + peptides (List[str]): List of peptide sequences. + alleles (List[str]): List of HLA alleles for predictions. Returns: - list: A list of dictionaries containing the binding affinity results or errors. + List[Dict[str, Any]]: List of prediction results. """ - results = [] - peptides_by_length = {} - - # Group peptides by their length - for peptide in peptides: - length = len(peptide) - peptides_by_length.setdefault(length, []).append(peptide) - - async with httpx.AsyncClient() as client: - for allele in alleles: - for length, peptides_subset in peptides_by_length.items(): - sequence_text = "\n".join( - [ - f">peptide{i}\n{peptide}" - for i, peptide in enumerate(peptides_subset) - ] - ) - - payload = { - "method": method, - "sequence_text": sequence_text, - "allele": allele, - "length": str(length), - "species": "human", - } - - retries = 0 - max_retries = 5 - while retries < max_retries: - try: - response = await client.post(IEDB_API_URL_CLASSI, data=payload) - - # Handle 403 and 500 errors with retry logic - if response.status_code in [403, 500]: - retries += 1 - sleep_time = 2**retries # Exponential backoff - logger.error( - f"Server error {response.status_code}. Retrying in {sleep_time} seconds..." - ) - await asyncio.sleep(sleep_time) - else: - response.raise_for_status() # Raise error for any other issues - results.append( - { - "allele": allele, - "length": length, - "peptides": peptides_subset, - "result": response.text, - } - ) - logger.info( - f"Successfully retrieved data for allele {allele} and length {length}." - ) - break # Break loop on success - except httpx.RequestError as e: - if retries == max_retries: - logger.error( - f"Max retries reached for allele {allele} and length {length}: {e}" - ) - results.append( - { - "allele": allele, - "length": length, - "peptides": peptides_subset, - "error": str(e), - } - ) - else: - retries += 1 - sleep_time = 2**retries - logger.error( - f"Request error. 
Retrying in {sleep_time} seconds for allele {allele} and length {length}: {e}" - ) - await asyncio.sleep(sleep_time) - except Exception as e: - logger.error(f"Unexpected error: {e}") - break - - return results + payload = {"peptides": peptides, "alleles": alleles} + + try: + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.post(CLASSI_URL, json=payload) + response.raise_for_status() + results = response.json() + return results + except (httpx.RequestError, httpx.HTTPStatusError) as e: + logger.error(f"Request error: {e}") + # Return an error in the same structure as successful results + return [{"peptides": peptides, "error": str(e)}] # List with error + except httpx.HTTPStatusError as e: + logger.error( + f"HTTP error: {e.response.status_code}, details: {e.response.text}" + ) + return {"error": e.response.text} + except Exception as e: + logger.error(f"Unexpected error: {e}") + return {"error": str(e)} diff --git a/apps/fastapi/src/app/services/postprocess.py b/apps/fastapi/src/app/services/postprocess.py index 07f33e2..40c110b 100644 --- a/apps/fastapi/src/app/services/postprocess.py +++ b/apps/fastapi/src/app/services/postprocess.py @@ -9,7 +9,7 @@ from app.core.config import settings from app.core.utils import generate_csv_key, read_s3_csv, upload_csv_to_s3 -from app.crud import crud_mhc_i_prediction +from app.crud.crud_mhc_i_prediction import crud_mhc_i_prediction from app.schemas.conformational_b_prediction import PredictionResult from app.schemas.linear_b_prediction import LBPredictionResult from app.schemas.mhc_i_prediction import MhcIPredictionResult @@ -81,29 +81,48 @@ async def process_classI_results( results: List[Dict[str, Any]], ) -> List[MhcIPredictionResult]: """ - Processes the Class I results returned by the IEDB API. + Processes the Class I results returned by the IEDB API or similar prediction tool. + - Extracts relevant peptide, allele, and affinity data. + - Formats the results and calculates the best binding affinity. """ peptide_data = {} + # Check if there's an error in the results + if isinstance(results, dict) and "error" in results: + logger.error(f"Error in results: {results['error']}") + raise HTTPException(status_code=400, detail=results["error"]) + for res in results: if "result" in res: - try: - df = pd.read_csv(StringIO(res["result"]), sep="\t") - if {"peptide", "allele", "ic50"}.issubset(df.columns): - for _, row in df.iterrows(): - peptide = row["peptide"] - allele = row["allele"] - ic50 = row["ic50"] - peptide_data.setdefault(peptide, {"binding_affinities": []}) - peptide_data[peptide]["binding_affinities"].append( - (allele, float(ic50)) - ) - else: - logger.warning(f"Unexpected columns in API response: {df.columns}") - except pd.errors.EmptyDataError: - logger.warning( - f"Received empty data from API for allele {res['allele']} and length {res['length']}." - ) + # Check if the result is a list of dictionaries instead of a string + if isinstance(res["result"], list): + # Handle the case where the result is a list of dictionaries + df = pd.DataFrame( + res["result"] + ) # Directly convert the list of dicts to a DataFrame + else: + try: + # Handle the case where the result is a string + df = pd.read_csv(StringIO(res["result"]), sep="\t") + except Exception as e: + logger.error(f"Error reading data for result {res}: {str(e)}") + raise HTTPException( + status_code=500, detail="Error processing results." 
+ ) + + # Check for required columns + if {"peptide", "allele", "affinity"}.issubset(df.columns): + for _, row in df.iterrows(): + peptide = row["peptide"] + allele = row["allele"] + affinity = row["affinity"] # Using 'affinity' instead of 'ic50' + + peptide_data.setdefault(peptide, {"binding_affinities": []}) + peptide_data[peptide]["binding_affinities"].append( + (allele, float(affinity)) + ) + else: + logger.warning(f"Unexpected columns in API response: {df.columns}") else: # Handle errors if any for peptide in res["peptides"]: @@ -114,15 +133,36 @@ async def process_classI_results( processed_results = [] for peptide, data in peptide_data.items(): binding_affinities = data.get("binding_affinities", []) - binding_affinity_str = "|".join( - [f"{allele}={ic50} nM" for allele, ic50 in binding_affinities] - ) + + # Ensure binding_affinities is a list and each element is a tuple of (allele, affinity) + if isinstance(binding_affinities, list) and all( + isinstance(x, tuple) and len(x) == 2 for x in binding_affinities + ): + # Log for debugging + logger.debug( + f"Binding affinities for peptide {peptide}: {binding_affinities}" + ) + + binding_affinity_str = "|".join( + [ + f"{allele}={affinity:.2f} nM" + for allele, affinity in binding_affinities + ] + ) + else: + logger.warning( + f"Binding affinities for peptide {peptide} is not formatted correctly: {binding_affinities}" + ) + binding_affinity_str = "" + + # Determine the best binding affinity (minimum affinity value) best_binding_affinity = ( f"{min(binding_affinities, key=lambda x: x[1])}" if binding_affinities else "" ) + # Append the formatted result for this peptide processed_results.append( MhcIPredictionResult( Peptide_Sequence=peptide, diff --git a/apps/nextjs/src/components/predictions/tables/mhc-peptide-dialog-cell.tsx b/apps/nextjs/src/components/predictions/tables/mhc-peptide-dialog-cell.tsx index fda1af6..c808f23 100644 --- a/apps/nextjs/src/components/predictions/tables/mhc-peptide-dialog-cell.tsx +++ b/apps/nextjs/src/components/predictions/tables/mhc-peptide-dialog-cell.tsx @@ -82,20 +82,22 @@ export function MhcPeptideDialogCell({ rowData }: PeptideDialogProps) { {Peptide_Sequence} -

-            TCR Recognition: {tcrRecognition}
-            Best Binding Affinity: {Best_Binding_Affinity}
-            Best pMHC Stability: {Best_pMHC_Stability}
+            TCR Recognition: {tcrRecognition}
+            Best Binding Affinity: {Best_Binding_Affinity}
+            Best pMHC Stability: {Best_pMHC_Stability}
diff --git a/apps/tools-fastapi/.dockerignore b/apps/tools-fastapi/.dockerignore new file mode 100644 index 0000000..ad31828 --- /dev/null +++ b/apps/tools-fastapi/.dockerignore @@ -0,0 +1,37 @@ +# Include any files or directories that you don't want to be copied to your +# container here (e.g., local build artifacts, temporary files, etc.). +# +# For more help, visit the .dockerignore file reference guide at +# https://docs.docker.com/go/build-context-dockerignore/ + +**/.DS_Store +**/__pycache__ +**/.venv +**/.classpath +**/.dockerignore +**/.env +**/.git +**/.gitignore +**/.project +**/.settings +**/.toolstarget +**/.vs +**/.vscode +**/*.*proj.user +**/*.dbmdl +**/*.jfm +**/bin +**/charts +**/docker-compose* +**/compose* +**/Dockerfile* +**/node_modules +**/npm-debug.log +**/obj +**/secrets.dev.yaml +**/values.dev.yaml +LICENSE +README.md + +# ruff +.ruff_cache/ diff --git a/apps/tools-fastapi/.env.example b/apps/tools-fastapi/.env.example new file mode 100644 index 0000000..6e3f190 --- /dev/null +++ b/apps/tools-fastapi/.env.example @@ -0,0 +1 @@ +BACKEND_CORS_ORIGINS="http://localhost:3000" diff --git a/apps/tools-fastapi/.gitignore b/apps/tools-fastapi/.gitignore new file mode 100644 index 0000000..f7a833a --- /dev/null +++ b/apps/tools-fastapi/.gitignore @@ -0,0 +1,179 @@ +# Created by https://www.toptal.com/developers/gitignore/api/python +# Edit at https://www.toptal.com/developers/gitignore?templates=python + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
+#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +.env.local +.env.stage +.env.prod + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +# poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +# End of https://www.toptal.com/developers/gitignore/api/python diff --git a/apps/tools-fastapi/.python-version b/apps/tools-fastapi/.python-version new file mode 100644 index 0000000..e4fba21 --- /dev/null +++ b/apps/tools-fastapi/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/apps/tools-fastapi/.vscode/launch.json b/apps/tools-fastapi/.vscode/launch.json new file mode 100644 index 0000000..0623662 --- /dev/null +++ b/apps/tools-fastapi/.vscode/launch.json @@ -0,0 +1,13 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Tools FastAPI", + "type": "debugpy", + "request": "launch", + "module": "uvicorn", + "args": ["api.main:app", "--reload"], + "jinja": true + } + ] +} diff --git a/apps/tools-fastapi/.vscode/settings.json b/apps/tools-fastapi/.vscode/settings.json new file mode 100644 index 0000000..d969f96 --- /dev/null +++ b/apps/tools-fastapi/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "python.testing.pytestArgs": ["tests"], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} diff --git a/apps/tools-fastapi/netMHCIIpan-4.3.readme b/apps/tools-fastapi/netMHCIIpan-4.3.readme new file mode 100644 index 0000000..5f131bc --- /dev/null +++ b/apps/tools-fastapi/netMHCIIpan-4.3.readme @@ -0,0 +1,142 @@ +NetMHCIIpan 4.3 INSTALLATION INSTRUCTIONS + + + DESCRIPTION + + The NetMHCIIpan 4.3 software predicts binding of peptides to MHC class II + molecules. The predictions are available for all three human MHC class II + isotypes: HLA-DR, HLA-DP and HLA-DQ, as well as for mouse molecules (H-2) and + cattle molecules (BoLA). + Version 4.3 is a new update to NetMHCIIpan-4.2 retrained on an extended set of + HLA-DP EL data as well as new data for HLA-DR and BoLA-DRB3. Further, the method + uses an updated version of the NNAlign_MA framework, allowing for prediction of + inverted peptide binders. + + The 4.3 method is described in detail in the following articles: + + Accurate prediction of HLA class II antigen presentation across all loci using + tailored data acquisition and refined machine learning + Jonas B. Nilsson, Saghar Kaabinejadian, Hooman Yari, Michel G. D. 
Kester, + Peter van Balen, William H. Hildebrand and Morten Nielsen + Science Advances, 24 Nov 2023. https://www.science.org/doi/10.1126/sciadv.adj6367 + + Previous version (4.2): + Machine learning reveals limited contribution of trans-only encoded variants to + the HLA-DQ immunopeptidome + Jonas Birkelund Nilsson, Saghar Kaabinejadian, Hooman Yari, Bjoern Peters, + Carolina Barra, Loren Gragert, William Hildebrand and Morten Nielsen + Communications Biology, 21 April 2023. https://doi.org/10.1038/s42003-023-04749-7 + + Previous version (4.1): + Accurate MHC Motif Deconvolution of Immunopeptidomics Data Reveals a Significant + Contribution of DRB3, 4 and 5 to the Total DR Immunopeptidome + Kaabinejadian S, Barra C, Alvarez B, Yari H, Hildebrand WH and Nielsen M + Front. Immunol. 13:835454. Published: 26 January 2022. doi: 10.3389/fimmu.2022.835454 + + Previous version (4.0): + Improved prediction of MHC II antigen presentation through integration and motif + deconvolution of mass spectrometry MHC eluted ligand data. + Reynisson B, Barra C, Kaabinejadian S, Hildebrand WH, Peters B, Nielsen M + J Proteome Res 2020 Apr 30. doi: 10.1021/acs.jproteome.9b00874. + + Previous version (3.2): + Improved methods for predicting peptide binding affinity to MHC class II molecules. + Jensen KK, Andreatta M, Marcatili P, Buus S, Greenbaum JA, Yan Z, Sette A, + Peters B, Nielsen M. + Immunology. 2018 Jan 6. doi: 10.1111/imm.12889. + + The previous version (3.1): + + Accurate pan-specific prediction of peptide-MHC class II binding + affinity with improved binding core identification. + Andreatta M, Karosiene E, Rasmussen M, Stryhn A, Buus S and Nielsen M. + Immunogenetics, Epub ahead of print, PubMed 26416257, Sep 29, 2015. + + The previous version (3.0): + + NetMHCIIpan-3.0, a common pan-specific MHC class II prediction method + including all three human MHC class II isotypes, HLA-DR, HLA-DP and HLA-DQ + Karosiene E, Rasmussen M, Blicher T, Lund O, Buus S, and Nielsen M. + Immunogenetics Oct;65(10):711-24, 2013. + + More information about the method can be found at: + + https://services.healthtech.dtu.dk/services/NetMHCIIpan-4.3/ + + DOWNLOAD + + The NetMHCIIpan 4.3 software is a property of DTU Section for Bioinformatics + It may be downloaded only by special agreement. For + academic users there is a download site at: + + + https://services.healthtech.dtu.dk/software.php + + Other users are requested to contact health-software@dtu.dk. + + + PRE-INSTALLATION + + netMHCIIpan 4.3 currently runs under Darwin (macOS x86_64 or arm64) and Linux (several + vendors). The package consists of two files: + + netMHCIIpan-4.3.readme this file + netMHCIIpan-4.3..tar.Z compressed TAR archive + + where 'unix' is the UNIX platform on which you are about to install. + + INSTALLATION + + 1. Uncompress and untar the package: + + tar -xvf netMHCIIpan-4.3..tar.gz + + This will produce a directory 'netMHCIIpan-4.3. + + 2. In the 'netMHCIIpan-4.3' directory edit the script 'netMHCIIpan': + + a. At the top of the file locate the part labelled "GENERAL SETTINGS: + CUSTOMIZE TO YOUR SITE" and set the 'NMHOME' variable to the full + path to the 'netMHCIIpan-4.3' directory on your system; + + b. Set TMPDIR to the full path to the tmp directory of you choice (must + be user writable); + + 3. 
In the 'netMHCIIpan-4.3/test' directory test the software: + + > ../netMHCIIpan -inptype 1 -f example.pep > example.pep.out + > ../netMHCIIpan -inptype 1 -f example.pep_context -context > example.pep_context.out + > ../netMHCIIpan -inptype 0 -f example.fsa > example.fsa.out + > ../netMHCIIpan -inptype 0 -f example.fsa -context -termAcon > example.fsa_context.out + > ../netMHCIIpan -f example.fsa -s -u > example.fsa.sorted.out + > ../netMHCIIpan -f example.fsa -hlaseqA DQA1_0101.fsa -hlaseq DQB1_0201.fsa > example.fsa_hlaseq_A+B.out + + The resulting "*.myout" files should be identical to the corresponding + "*.out" files provided in the package. + + 4. Finish the installation: + + a. Copy or link the 'netMHCIIpan' script to a directory in the users' + path. + + b. Copy the 'netMHCIIpan.1' file to a location in your manual system. + If you need a compiled version try running: + + man -d netMHCIIpan.1 |compress >netMHCIIpan.Z + + or: + + neqn netMHCIIpan.1 |tbl |nroff -man |col |compress >netMHCIIpan.Z + + 5. Enjoy ... + + + PROBLEMS + + Contact health-software@dtu.dk, in case of problems. + + Questions on the _scientific_ aspects of the NetMHCIIpan method should be + sent to dr Morten Nielsen, morni@dtu.dk. + + 9 Oct 2023 + M. Nielsen diff --git a/apps/tools-fastapi/netMHCpan-4.1.readme b/apps/tools-fastapi/netMHCpan-4.1.readme new file mode 100644 index 0000000..7c6c6a1 --- /dev/null +++ b/apps/tools-fastapi/netMHCpan-4.1.readme @@ -0,0 +1,186 @@ +NetMHCpan 4.1 INSTALLATION INSTRUCTIONS + + + DESCRIPTION + + The NetMHCpan 4.1 software predicts binding of peptides to any + known MHC molecule using artificial neural networks (ANNs). The + method is trained on a combination of more than 850,000 quantitative + Binding Affinity (BA) and Mass-Spectrometry Eluted Ligands (EL) + peptides. The BA data covers 201 MHC molecules from human (HLA-A, B, + C, E), mouse (H-2), cattle (BoLA), primates (Patr, Mamu, Gogo), swine + (SLA) and equine (Eqca). The EL data covers 289 MHC molecules from + human (HLA-A, B, C, E), mouse (H-2), cattle (BoLA), primates (Patr, + Mamu, Gogo), swine (SLA), equine (Eqca) and dog (DLA). Furthermore, + the user can obtain predictions to any custom MHC class I molecule + by uploading a full length MHC protein sequence. Predictions can be + made for peptides of any length. + + Version 4.1 is described in the following publication: + + NetMHCpan-4.1 and NetMHCIIpan-4.0: Improved predictions of MHC antigen + presentation by concurrent motif deconvolution and integration of MS + MHC eluted ligand data. + Birkir Reynisson, Bruno Alvarez, Sinu Paul, Bjoern Peters and Morten Nielsen + Submitted, 2020 + + + Earlier Version. 4.0: + + NetMHCpan-4.0: Improved Peptide-MHC Class I Interaction Predictions + Integrating Eluted Ligand and Peptide Binding Affinity Data + + Vanessa Jurtz, Sinu Paul, Massimo Andreatta, Paolo Marcatili, Bjoern Peters, + and Morten Nielsen + The Journal of Immunology (2017) ji1700893; DOI: 10.4049/jimmunol.1700893 + + The original paper: + NetMHCpan, a Method for Quantitative Predictions of Peptide Binding to Any + HLA-A and -B Locus Protein of Known Sequence. + Morten Nielsen et al. + PLoS ONE 2(8): e796. doi:10.1371/journal.pone.0000796, 2007. 
+ + More information about the method, including instructions, guidelines, and + output description can be found at: + + http://www.cbs.dtu.dk/services/NetMHCpan/ + + + DOWNLOAD + + The netMHCpan 4.1 software package is a property of Department of + Health Technology, Section of Bioinformatics The Technical University + of Denmark It may be downloaded only by special agreement. + For academic users there is a download site at: + + http://www.cbs.dtu.dk/cgi-bin/nph-sw_request?netMHCpan + + Other users are requested to contact software@cbs.dtu.dk. + + PRE-INSTALLATION + + netMHCpan 4.1 currently runs under Darwin (MacOSX), and Linux. + The package consists of two files: + + netMHCpan-4.1.readme this file + netMHCpan-4.1..tar.gz compressed TAR archive + + where 'unix' is the UNIX platform on which you are about to install. After + installation the software will occupy less than 45 MB of diskspace. + + INSTALLATION + + 1. Uncompress and untar the package: + + cat netMHCpan-4.1..tar.gz | uncompress | tar xvf - + + This will produce a directory 'netMHCpan-4.1'. + + 2. From the CBS website download the file: + + https://services.healthtech.dtu.dk/services/NetMHCpan-4.1/data.tar.gz + + Put it in the 'netMHCpan-4.1' directory and + then untar it: + + tar -xvf data.tar.gz + + This will produce a directory 'data'. It is necessary for the + NetMHCpan 4.1 software to operate; once it is installed you may delete + the 'data.tar.gz' file; it will not be needed. + + 2. In the 'netMHCpan-4.1' directory edit the script 'netMHCpan': + + a. At the top of the file locate the part labelled "GENERAL SETTINGS: + CUSTOMIZE TO YOUR SITE" and set the 'NMHOME' variable to the full + path to the 'netMHCpan-4.1' directory on your system; + + b. Set TMPDIR to the full path to the temporary directory of you choice. It must + be user-writable. You may for example set it to $NMHOME/tmp (and create + the tmp folder in the netMHCpan-4.1 directory). + + 3. In the 'netMHCpan-4.1/test' directory test the software: + + > ../netMHCpan -p test.pep > test.pep.myout + > ../netMHCpan test.fsa > test.fsa.myout + > ../netMHCpan -hlaseq B0702.fsa -p test.pep > test.pep_userMHC.myout + > ../netMHCpan -p test.pep -BA > test.pep_BA.out + > ../netMHCpan -p test.pep -BA -xls -a HLA-A01:01,HLA-A02:01 -xlsfile my_NetMHCpan_out.xls + + The resulting ".myout" files should not differ from the corresponding + ".out" files provided in the package other than in the directory names + and small rounding errors. + + 4. Finish the installation: + + a. Copy or link the 'netMHCpan' file to a directory in the users' path. + + b. Copy the 'netMHCpan.1' file to a location in your manual system. If + you need a compiled version try running: + + man -d netMHCpan.1 | compress >netMHCpan.Z + + or: + + neqn netMHCpan.1 | tbl | nroff -man | col | compress >netMHCpan.Z + + 5. Enjoy ... + +PROBLEMS + + Contact packages@cbs.dtu.dk in case of problems. + + Questions on the scientific aspects of the netMHCpan method should be sent + to Dr. Morten Nielsen, morni@dtu.dk. + + DTU, 24 April 2020 + M. Nielsen + +-------------------------------------------------------------------------------- + +#! /bin/tcsh -f + +# This the main NetMHCpan 4.1 script. It only acts as the frontend to the +# software proper, a compiled binary. 
+# +# VERSION: 2019 Dec 9 launch +# + +############################################################################### +# GENERAL SETTINGS: CUSTOMIZE TO YOUR SITE +############################################################################### + +# full path to the NetMHCpan 4.0 directory (mandatory) +setenv NMHOME /opt/netMHCpan-4.1 + +# determine where to store temporary files (must be writable to all users) + +if ( ${?TMPDIR} == 0 ) then + setenv TMPDIR $NMHOME/tmp +endif + +echo "Using TMPDIR: " $TMPDIR + +# determine platform (do not change this unless you don't have 'uname'!) +setenv UNIX `uname -s` +setenv AR `uname -m` + +############################################################################### +# NOTHING SHOULD NEED CHANGING BELOW THIS LINE! +############################################################################### + +# other settings +set PLATFORM = `echo $UNIX $AR | awk '{print $1"_"$2}'` +setenv NETMHCpan $NMHOME/$PLATFORM +setenv DTUIBSWWW www +setenv NetMHCpanWWWPATH /services/NetMHCpan/tmp/ +setenv NetMHCpanWWWDIR /usr/opt/www/pub/CBS/services/NetMHCpan/tmp + +# main ======================================================================== +if ( -x $NETMHCpan/bin/netMHCpan ) then + $NETMHCpan/bin/netMHCpan $* +else + echo netMHCpan: no binaries found for $PLATFORM $NETMHCpan/bin/netMHCpan +endif + +# end of script =============================================================== diff --git a/apps/tools-fastapi/package.json b/apps/tools-fastapi/package.json new file mode 100644 index 0000000..4fa5d29 --- /dev/null +++ b/apps/tools-fastapi/package.json @@ -0,0 +1,8 @@ +{ + "name": "prediction-tools-fastapi", + "version": "0.1.0", + "scripts": { + "dev": "uv run fastapi dev src/app/main.py", + "prod": "uv run fastapi run src/app/main.py --port 8000" + } +} diff --git a/apps/tools-fastapi/pyproject.toml b/apps/tools-fastapi/pyproject.toml new file mode 100644 index 0000000..971fe1e --- /dev/null +++ b/apps/tools-fastapi/pyproject.toml @@ -0,0 +1,29 @@ +[project] +name = "prediction-tools-fastapi" +version = "0.1.0" +description = "FastAPI for B-cell and T-cell epitope prediction tools." 
+authors = [ + {name = "trevorpfiz", email = "elektrikspark@gmail.com"}, + {name = "zacharypfiz", email = "ztpfizcode@gmail.com"}, +] +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "fastapi[standard]>=0.115.0", + "httpx>=0.27.2", + "mhctools>=1.9.0", + "pandas>=2.2.3", + "pydantic-settings>=2.5.2", + "requests>=2.32.3", +] + +[tool.uv] +dev-dependencies = [ + "pylint-pydantic>=0.3.2", + "pytest>=8.3.3", + "python-dotenv>=1.0.1", + "ruff>=0.6.8", +] + +[tool.uv.sources] +mhctools = { git = "https://github.com/hammerlab/mhctools.git" } diff --git a/apps/tools-fastapi/requirements.txt b/apps/tools-fastapi/requirements.txt new file mode 100644 index 0000000..a71c2b3 --- /dev/null +++ b/apps/tools-fastapi/requirements.txt @@ -0,0 +1,311 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile pyproject.toml -o requirements.txt +absl-py==2.1.0 + # via + # keras + # tensorboard + # tensorflow +annotated-types==0.7.0 + # via pydantic +anyio==4.6.2.post1 + # via + # httpx + # starlette + # watchfiles +appdirs==1.4.4 + # via + # datacache + # mhcflurry +astroid==2.15.8 + # via pylint +astunparse==1.6.3 + # via tensorflow +biopython==1.84 + # via varcode +certifi==2024.8.30 + # via + # httpcore + # httpx + # requests +charset-normalizer==3.4.0 + # via requests +click==8.1.7 + # via + # typer + # uvicorn +datacache==1.4.1 + # via pyensembl +dill==0.3.9 + # via pylint +dnspython==2.7.0 + # via email-validator +email-validator==2.2.0 + # via fastapi +fastapi==0.115.2 + # via prediction-tools-fastapi (pyproject.toml) +fastapi-cli==0.0.5 + # via fastapi +flatbuffers==24.3.25 + # via tensorflow +gast==0.6.0 + # via tensorflow +google-pasta==0.2.0 + # via tensorflow +grpcio==1.67.0 + # via + # tensorboard + # tensorflow +gtfparse==2.5.0 + # via pyensembl +h11==0.14.0 + # via + # httpcore + # uvicorn +h5py==3.12.1 + # via + # keras + # tensorflow +httpcore==1.0.6 + # via httpx +httptools==0.6.4 + # via uvicorn +httpx==0.27.2 + # via + # prediction-tools-fastapi (pyproject.toml) + # fastapi +idna==3.10 + # via + # anyio + # email-validator + # httpx + # requests +isort==5.13.2 + # via pylint +jinja2==3.1.4 + # via fastapi +joblib==1.4.2 + # via scikit-learn +keras==3.6.0 + # via tensorflow +lazy-object-proxy==1.10.0 + # via astroid +libclang==18.1.1 + # via tensorflow +markdown==3.7 + # via tensorboard +markdown-it-py==3.0.0 + # via rich +markupsafe==3.0.2 + # via + # jinja2 + # werkzeug +mccabe==0.7.0 + # via pylint +mdurl==0.1.2 + # via markdown-it-py +memoized-property==1.0.3 + # via + # pyensembl + # varcode +mhcflurry==2.1.4 + # via mhctools +mhcgnomes==1.8.6 + # via mhcflurry +mhcnames==0.4.8 + # via mhctools +mhctools==1.9.0 + # via prediction-tools-fastapi (pyproject.toml) +ml-dtypes==0.4.1 + # via + # keras + # tensorflow +mock==5.1.0 + # via datacache +namex==0.0.8 + # via keras +numpy==1.26.4 + # via + # biopython + # h5py + # keras + # mhcgnomes + # mhctools + # ml-dtypes + # pandas + # pyarrow + # scikit-learn + # scipy + # tensorboard + # tensorflow + # varcode +opt-einsum==3.4.0 + # via tensorflow +optree==0.13.0 + # via keras +packaging==24.1 + # via + # keras + # tensorboard + # tensorflow +pandas==2.2.3 + # via + # prediction-tools-fastapi (pyproject.toml) + # datacache + # mhcflurry + # mhcgnomes + # mhctools + # sercol + # varcode +platformdirs==4.3.6 + # via pylint +polars==0.20.31 + # via gtfparse +progressbar33==2.4 + # via datacache +protobuf==4.25.5 + # via + # tensorboard + # tensorflow +pyarrow==14.0.2 + # via gtfparse 
+pydantic==2.9.2 + # via + # fastapi + # pydantic-settings +pydantic-core==2.23.4 + # via pydantic +pydantic-settings==2.6.0 + # via prediction-tools-fastapi (pyproject.toml) +pyensembl==2.3.13 + # via + # mhctools + # varcode +pygments==2.18.0 + # via rich +pylint==2.17.7 + # via pyensembl +python-dateutil==2.9.0.post0 + # via pandas +python-dotenv==1.0.1 + # via + # pydantic-settings + # uvicorn +python-multipart==0.0.12 + # via fastapi +pytz==2024.2 + # via pandas +pyvcf3==1.0.3 + # via varcode +pyyaml==6.0.2 + # via + # mhcflurry + # mhcgnomes + # uvicorn +requests==2.32.3 + # via + # prediction-tools-fastapi (pyproject.toml) + # datacache + # tensorflow +rich==13.9.2 + # via + # keras + # typer +scikit-learn==1.5.2 + # via mhcflurry +scipy==1.14.1 + # via scikit-learn +sercol==1.0.0 + # via + # mhctools + # varcode +serializable==0.4.1 + # via + # mhcgnomes + # pyensembl + # sercol + # varcode +setuptools==75.2.0 + # via + # pyvcf3 + # tensorboard + # tensorflow +shellingham==1.5.4 + # via typer +simplejson==3.19.3 + # via + # sercol + # serializable +six==1.16.0 + # via + # astunparse + # google-pasta + # mhcflurry + # mhcnames + # python-dateutil + # tensorboard + # tensorflow +sniffio==1.3.1 + # via + # anyio + # httpx +starlette==0.40.0 + # via fastapi +tensorboard==2.17.1 + # via tensorflow +tensorboard-data-server==0.7.2 + # via tensorboard +tensorflow==2.17.0 + # via + # mhcflurry + # tf-keras +termcolor==2.5.0 + # via tensorflow +tf-keras==2.17.0 + # via mhcflurry +threadpoolctl==3.5.0 + # via scikit-learn +tinytimer==0.0.0 + # via pyensembl +tomlkit==0.13.2 + # via pylint +tqdm==4.66.5 + # via mhcflurry +typechecks==0.1.0 + # via + # datacache + # pyensembl + # serializable +typer==0.12.5 + # via fastapi-cli +typing-extensions==4.12.2 + # via + # fastapi + # optree + # pydantic + # pydantic-core + # tensorflow + # typer +tzdata==2024.2 + # via pandas +urllib3==2.2.3 + # via requests +uvicorn==0.32.0 + # via + # fastapi + # fastapi-cli +uvloop==0.21.0 + # via uvicorn +varcode==1.2.1 + # via mhctools +watchfiles==0.24.0 + # via uvicorn +websockets==13.1 + # via uvicorn +werkzeug==3.0.4 + # via tensorboard +wheel==0.44.0 + # via astunparse +wrapt==1.16.0 + # via + # astroid + # tensorflow diff --git a/apps/tools-fastapi/src/app/__init__.py b/apps/tools-fastapi/src/app/__init__.py new file mode 100644 index 0000000..3dc1f76 --- /dev/null +++ b/apps/tools-fastapi/src/app/__init__.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/apps/tools-fastapi/src/app/api/__init__.py b/apps/tools-fastapi/src/app/api/__init__.py new file mode 100644 index 0000000..3dc1f76 --- /dev/null +++ b/apps/tools-fastapi/src/app/api/__init__.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/apps/tools-fastapi/src/app/api/api_v1/__init__.py b/apps/tools-fastapi/src/app/api/api_v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/tools-fastapi/src/app/api/api_v1/api.py b/apps/tools-fastapi/src/app/api/api_v1/api.py new file mode 100644 index 0000000..a8836a7 --- /dev/null +++ b/apps/tools-fastapi/src/app/api/api_v1/api.py @@ -0,0 +1,13 @@ +from fastapi import APIRouter + +from app.api.api_v1.endpoints import prediction + +api_router = APIRouter() + + +api_router.include_router( + prediction.router, + prefix="/prediction", + tags=["prediction"], + responses={404: {"description": "Not found"}}, +) diff --git a/apps/tools-fastapi/src/app/api/api_v1/endpoints/__init__.py b/apps/tools-fastapi/src/app/api/api_v1/endpoints/__init__.py new file mode 100644 index 0000000..e69de29 diff 
--git a/apps/tools-fastapi/src/app/api/api_v1/endpoints/prediction.py b/apps/tools-fastapi/src/app/api/api_v1/endpoints/prediction.py new file mode 100644 index 0000000..921630e --- /dev/null +++ b/apps/tools-fastapi/src/app/api/api_v1/endpoints/prediction.py @@ -0,0 +1,46 @@ +from typing import List + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from app.services.inference import run_binding_predictions + +router = APIRouter() + + +# Input data model +class PredictionRequest(BaseModel): + peptides: List[str] + alleles: List[str] + + +@router.post("/netmhcpan/") +async def run_netmhcpan(request: PredictionRequest): + """ + Endpoint to run NetMHCpan predictions on input peptides and alleles. + """ + try: + predictions = await run_binding_predictions( + peptides=request.peptides, + alleles=request.alleles, + predictor_type="netmhcpan", + ) + return predictions + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/netmhciipan/") +async def run_netmhciipan(request: PredictionRequest): + """ + Endpoint to run NetMHCIIpan predictions on input peptides and alleles. + """ + try: + predictions = await run_binding_predictions( + peptides=request.peptides, + alleles=request.alleles, + predictor_type="netmhciipan", + ) + return predictions + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/apps/tools-fastapi/src/app/core/__init__.py b/apps/tools-fastapi/src/app/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/tools-fastapi/src/app/core/config.py b/apps/tools-fastapi/src/app/core/config.py new file mode 100644 index 0000000..747ce42 --- /dev/null +++ b/apps/tools-fastapi/src/app/core/config.py @@ -0,0 +1,49 @@ +import logging +from typing import List, Union + +from dotenv import load_dotenv +from pydantic import AnyHttpUrl, Field, field_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + +log_format = logging.Formatter("%(asctime)s : %(levelname)s - %(message)s") + +# root logger +root_logger = logging.getLogger() +root_logger.setLevel(logging.INFO) + +# standard stream handler +stream_handler = logging.StreamHandler() +stream_handler.setFormatter(log_format) +root_logger.addHandler(stream_handler) + +logger = logging.getLogger(__name__) + +load_dotenv() + + +class Settings(BaseSettings): + ENV: str = Field(default="", env="ENV") + BACKEND_CORS_ORIGINS: Union[List[AnyHttpUrl], List[str]] = Field( + default=["http://localhost:3000"], env="BACKEND_CORS_ORIGINS" + ) + + # Validator to parse both comma-separated strings and lists from .env file + @field_validator("BACKEND_CORS_ORIGINS", mode="before") + def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]: + if isinstance(v, str) and not v.startswith("["): + # Split the string by commas and remove whitespace + return [i.strip() for i in v.split(",")] + elif isinstance(v, list): + return v + raise ValueError("Invalid format for BACKEND_CORS_ORIGINS") + + # Project details + API_VERSION: str = "/api/v1" + PROJECT_NAME: str = "Epitope Prediction Tools FastAPI" + PROJECT_DESCRIPTION: str = "A simple FastAPI app" + + # Pydantic configuration to load environment variables from .env + model_config = SettingsConfigDict(env_file=".env") + + +settings = Settings() diff --git a/apps/tools-fastapi/src/app/core/utils.py b/apps/tools-fastapi/src/app/core/utils.py new file mode 100644 index 0000000..eea436a --- /dev/null +++ b/apps/tools-fastapi/src/app/core/utils.py @@ 
-0,0 +1,3 @@ +import logging + +logger = logging.getLogger(__name__) diff --git a/apps/tools-fastapi/src/app/main.py b/apps/tools-fastapi/src/app/main.py new file mode 100644 index 0000000..e555879 --- /dev/null +++ b/apps/tools-fastapi/src/app/main.py @@ -0,0 +1,65 @@ +import logging + +from fastapi import APIRouter, FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.routing import APIRoute + +from app.api.api_v1.api import api_router +from app.core.config import settings + +info_router = APIRouter() + + +@info_router.get("/", status_code=200, include_in_schema=False) +async def info(): + return [{"Status": "API Running"}] + + +def custom_generate_unique_id(route: APIRoute): + """Generates a custom ID when using the TypeScript Generator Client + + Args: + route (APIRoute): The route to be customised + + Returns: + str: tag-route_name, e.g. items-CreateItem + """ + return f"{route.tags[0]}-{route.name}" + + +def get_application(): + _app = FastAPI( + title=settings.PROJECT_NAME, + description=settings.PROJECT_DESCRIPTION, + generate_unique_id_function=custom_generate_unique_id, + openapi_url=f"{settings.API_VERSION}/openapi.json", + root_path=f"/{settings.ENV}" if settings.ENV in ["stage", "prod"] else "", + ) + + if settings.ENV == "" or settings.ENV == "dev": + logger = logging.getLogger("uvicorn") + logger.warning("Running in development mode - allowing CORS for all origins") + _app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + else: + _app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # TODO: Use BACKEND_CORS_ORIGINS + # allow_origins=settings.BACKEND_CORS_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + _app.include_router(api_router, prefix=settings.API_VERSION) + _app.include_router(info_router, tags=[""]) + + return _app + + +app = get_application() diff --git a/apps/tools-fastapi/src/app/services/__init__.py b/apps/tools-fastapi/src/app/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/tools-fastapi/src/app/services/inference.py b/apps/tools-fastapi/src/app/services/inference.py new file mode 100644 index 0000000..886a7d0 --- /dev/null +++ b/apps/tools-fastapi/src/app/services/inference.py @@ -0,0 +1,114 @@ +import logging +from typing import Any, Dict, List + +from mhcnames import normalize_allele_name + +# Import the correct elution score versions of the predictors +from mhctools import NetMHCIIpan43, NetMHCpan41 + +logger = logging.getLogger(__name__) + + +async def run_binding_predictions( + peptides: List[str], alleles: List[str], predictor_type: str +) -> List[Dict[str, Any]]: + """ + Helper function to group peptides by length and run predictions. + + Args: + peptides (List[str]): List of peptides. + alleles (List[str]): List of alleles. + predictor_type (str): Either 'netmhcpan' or 'netmhciipan'. + + Returns: + List[Dict[str, Any]]: List of prediction results for each allele and peptide length. 
+ """ + + # Validate alleles + try: + alleles = [normalize_allele_name(a) for a in alleles] + logger.info(f"Normalized alleles: {alleles}") + except Exception as e: + logger.error(f"Error normalizing alleles: {e}") + return {"error": str(e)} + + results = [] + peptides_by_length = {} + + # Group peptides by their length + for peptide in peptides: + length = len(peptide) + + # Ensure that the length is valid + if length is None or not isinstance(length, int): + logger.error(f"Invalid peptide length for peptide: {peptide}") + continue + + peptides_by_length.setdefault(length, []).append(peptide) + + # Initialize predictor based on the type + predictor = None + if predictor_type == "netmhcpan": + # Use NetMHCpan41_EL for elution score mode + predictor = NetMHCpan41(alleles=alleles) + elif predictor_type == "netmhciipan": + # Use NetMHCIIpan43_EL for elution score mode + predictor = NetMHCIIpan43(alleles=alleles) + + if not predictor: + raise ValueError(f"Unknown predictor type: {predictor_type}") + + for allele in alleles: + for length, peptides_subset in peptides_by_length.items(): + try: + # Ensure valid peptide length before processing + if length is None or not isinstance(length, int): + logger.error( + f"Skipping invalid length for peptides: {peptides_subset}" + ) + continue + + logger.info( + f"Predicting subsequences for peptides: {peptides_subset} and length: {length}" + ) + + binding_predictions = predictor.predict_subsequences( + {f"seq{i}": seq for i, seq in enumerate(peptides_subset)}, + peptide_lengths=[length], + ) + + if not binding_predictions: + logger.error( + f"No predictions for allele {allele} and length {length}" + ) + continue + + # Convert predictions to a DataFrame and then to a dictionary + df = binding_predictions.to_dataframe() + logger.info(f"Prediction DataFrame: {df}") + logger.info(f"Converted to dict: {df.to_dict(orient='records')}") + results.append( + { + "allele": allele, + "length": length, + "peptides": peptides_subset, + "result": df.to_dict(orient="records"), + } + ) + logger.info( + f"Successfully predicted binding affinity for allele {allele} and length {length}." 
+ ) + except Exception as e: + logger.error( + f"Error processing allele {allele} and length {length}: {e}" + ) + results.append( + { + "allele": allele, + "length": length, + "peptides": peptides_subset, + "error": str(e), + } + ) + + return results diff --git a/apps/tools-fastapi/uv.lock b/apps/tools-fastapi/uv.lock new file mode 100644 index 0000000..f569f9b --- /dev/null +++ b/apps/tools-fastapi/uv.lock @@ -0,0 +1,1663 @@ +version = 1 +requires-python = ">=3.12" +resolution-markers = [ + "python_full_version < '3.13'", + "python_full_version >= '3.13'", +] + +[[package]] +name = "absl-py" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7a/8f/fc001b92ecc467cc32ab38398bd0bfb45df46e7523bf33c2ad22a505f06e/absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff", size = 118055 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/ad/e0d3c824784ff121c03cc031f944bc7e139a8f1870ffd2845cc2dd76f6c4/absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308", size = 133706 }, +] + +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, +] + +[[package]] +name = "anyio" +version = "4.6.2.post1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "sniffio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9f/09/45b9b7a6d4e45c6bcb5bf61d19e3ab87df68e0601fa8c5293de3542546cc/anyio-4.6.2.post1.tar.gz", hash = "sha256:4c8bc31ccdb51c7f7bd251f51c609e038d63e34219b44aa86e47576389880b4c", size = 173422 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e4/f5/f2b75d2fc6f1a260f340f0e7c6a060f4dd2961cc16884ed851b0d18da06a/anyio-4.6.2.post1-py3-none-any.whl", hash = "sha256:6d170c36fba3bdd840c73d3868c1e777e33676a69c3a72cf0a0d5d6d8009b61d", size = 90377 }, +] + +[[package]] +name = "appdirs" +version = "1.4.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/d8/05696357e0311f5b5c316d7b95f46c669dd9c15aaeecbb48c7d0aeb88c40/appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41", size = 13470 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128", size = 9566 }, +] + +[[package]] +name = "astroid" +version = "2.15.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "lazy-object-proxy" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/58/3d/c18b0854d0d2eb3aca20c149cff5c90e6b84a5366066768d98636f5045ed/astroid-2.15.8.tar.gz", hash = "sha256:6c107453dffee9055899705de3c9ead36e74119cee151e5a9aaf7f0b0e020a6a", size = 344362 } 
+wheels = [ + { url = "https://files.pythonhosted.org/packages/16/b6/c0b5394ec6149e0129421f1a762b805e0e583974bc3cd65e3c7ce7c95444/astroid-2.15.8-py3-none-any.whl", hash = "sha256:1aa149fc5c6589e3d0ece885b4491acd80af4f087baafa3fb5203b113e68cd3c", size = 278329 }, +] + +[[package]] +name = "astunparse" +version = "1.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, + { name = "wheel" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f3/af/4182184d3c338792894f34a62672919db7ca008c89abee9b564dd34d8029/astunparse-1.6.3.tar.gz", hash = "sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872", size = 18290 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2b/03/13dde6512ad7b4557eb792fbcf0c653af6076b81e5941d36ec61f7ce6028/astunparse-1.6.3-py2.py3-none-any.whl", hash = "sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8", size = 12732 }, +] + +[[package]] +name = "biopython" +version = "1.84" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/7f/eaca4de03f0ee06c9d578d2730fd55858a57cee3620c62d3bc17b5da5447/biopython-1.84.tar.gz", hash = "sha256:60fbe6f996e8a6866a42698c17e552127d99a9aab3259d6249fbaabd0e0cc7b4", size = 25793001 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f7/f6/a61af0d2c8c04e446bce4727e8124797132858f518b6d6543d0e7213abed/biopython-1.84-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ba58a6d76288333c5f178a426116953fa68204bd0cfc401694087dd4f96d0059", size = 2755863 }, + { url = "https://files.pythonhosted.org/packages/e9/1a/25c7df41987383070987f7b9842f48d3a33b0a78a85c2ca9d93ed810fa2a/biopython-1.84-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ee3566f6dc3acf20e238540daf896f0af20cff531521bf41fdf5143f73e209ae", size = 2738072 }, + { url = "https://files.pythonhosted.org/packages/a2/b2/c7f2a0a151208c634ac1eaa5d6345899659b1d5a700a84ef2e4f2b0e80a9/biopython-1.84-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89ef3967f5a88b5bb6344bef75ae83386de53fed3966d5c8c334ad885f8db08a", size = 3186633 }, + { url = "https://files.pythonhosted.org/packages/46/37/7db2bcbb396edba3f767dd89ac23ef5adc35c7a92ef3912c06d1e71469e1/biopython-1.84-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61765b71f84814a1eeb55ab222f43330aa7ad3e55ab91e8b444706149c67a281", size = 3206061 }, + { url = "https://files.pythonhosted.org/packages/b2/12/6c9d73cbb8c9d19ab4187aaf187f967de6e83738947b7180fdd8bc9211a2/biopython-1.84-cp312-cp312-win32.whl", hash = "sha256:52b6098f47d6b90fc8a5e8579b81ee50047e9108f0976e69c891ae0c4817e42d", size = 2756622 }, + { url = "https://files.pythonhosted.org/packages/d1/53/91d12cc254a804c797afaefec91ede04bc1f7cbd788a04ebbea9e31ee0cf/biopython-1.84-cp312-cp312-win_amd64.whl", hash = "sha256:ecff2fcf5da29b600474c0bfcdbbac0f98b25e22fe60a853d0ee798c00f7396c", size = 2792652 }, +] + +[[package]] +name = "certifi" +version = "2024.8.30" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b0/ee/9b19140fe824b367c04c5e1b369942dd754c4c5462d5674002f75c4dedc1/certifi-2024.8.30.tar.gz", hash = "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9", size = 168507 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/12/90/3c9ff0512038035f59d279fddeb79f5f1eccd8859f06d6163c58798b9487/certifi-2024.8.30-py3-none-any.whl", hash = 
"sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8", size = 167321 }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/4f/e1808dc01273379acc506d18f1504eb2d299bd4131743b9fc54d7be4df1e/charset_normalizer-3.4.0.tar.gz", hash = "sha256:223217c3d4f82c3ac5e29032b3f1c2eb0fb591b72161f86d93f5719079dae93e", size = 106620 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d3/0b/4b7a70987abf9b8196845806198975b6aab4ce016632f817ad758a5aa056/charset_normalizer-3.4.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0713f3adb9d03d49d365b70b84775d0a0d18e4ab08d12bc46baa6132ba78aaf6", size = 194445 }, + { url = "https://files.pythonhosted.org/packages/50/89/354cc56cf4dd2449715bc9a0f54f3aef3dc700d2d62d1fa5bbea53b13426/charset_normalizer-3.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:de7376c29d95d6719048c194a9cf1a1b0393fbe8488a22008610b0361d834ecf", size = 125275 }, + { url = "https://files.pythonhosted.org/packages/fa/44/b730e2a2580110ced837ac083d8ad222343c96bb6b66e9e4e706e4d0b6df/charset_normalizer-3.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4a51b48f42d9358460b78725283f04bddaf44a9358197b889657deba38f329db", size = 119020 }, + { url = "https://files.pythonhosted.org/packages/9d/e4/9263b8240ed9472a2ae7ddc3e516e71ef46617fe40eaa51221ccd4ad9a27/charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b295729485b06c1a0683af02a9e42d2caa9db04a373dc38a6a58cdd1e8abddf1", size = 139128 }, + { url = "https://files.pythonhosted.org/packages/6b/e3/9f73e779315a54334240353eaea75854a9a690f3f580e4bd85d977cb2204/charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee803480535c44e7f5ad00788526da7d85525cfefaf8acf8ab9a310000be4b03", size = 149277 }, + { url = "https://files.pythonhosted.org/packages/1a/cf/f1f50c2f295312edb8a548d3fa56a5c923b146cd3f24114d5adb7e7be558/charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d59d125ffbd6d552765510e3f31ed75ebac2c7470c7274195b9161a32350284", size = 142174 }, + { url = "https://files.pythonhosted.org/packages/16/92/92a76dc2ff3a12e69ba94e7e05168d37d0345fa08c87e1fe24d0c2a42223/charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cda06946eac330cbe6598f77bb54e690b4ca93f593dee1568ad22b04f347c15", size = 143838 }, + { url = "https://files.pythonhosted.org/packages/a4/01/2117ff2b1dfc61695daf2babe4a874bca328489afa85952440b59819e9d7/charset_normalizer-3.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07afec21bbbbf8a5cc3651aa96b980afe2526e7f048fdfb7f1014d84acc8b6d8", size = 146149 }, + { url = "https://files.pythonhosted.org/packages/f6/9b/93a332b8d25b347f6839ca0a61b7f0287b0930216994e8bf67a75d050255/charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6b40e8d38afe634559e398cc32b1472f376a4099c75fe6299ae607e404c033b2", size = 140043 }, + { url = "https://files.pythonhosted.org/packages/ab/f6/7ac4a01adcdecbc7a7587767c776d53d369b8b971382b91211489535acf0/charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b8dcd239c743aa2f9c22ce674a145e0a25cb1566c495928440a181ca1ccf6719", size = 148229 }, + { url = 
"https://files.pythonhosted.org/packages/9d/be/5708ad18161dee7dc6a0f7e6cf3a88ea6279c3e8484844c0590e50e803ef/charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:84450ba661fb96e9fd67629b93d2941c871ca86fc38d835d19d4225ff946a631", size = 151556 }, + { url = "https://files.pythonhosted.org/packages/5a/bb/3d8bc22bacb9eb89785e83e6723f9888265f3a0de3b9ce724d66bd49884e/charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:44aeb140295a2f0659e113b31cfe92c9061622cadbc9e2a2f7b8ef6b1e29ef4b", size = 149772 }, + { url = "https://files.pythonhosted.org/packages/f7/fa/d3fc622de05a86f30beea5fc4e9ac46aead4731e73fd9055496732bcc0a4/charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1db4e7fefefd0f548d73e2e2e041f9df5c59e178b4c72fbac4cc6f535cfb1565", size = 144800 }, + { url = "https://files.pythonhosted.org/packages/9a/65/bdb9bc496d7d190d725e96816e20e2ae3a6fa42a5cac99c3c3d6ff884118/charset_normalizer-3.4.0-cp312-cp312-win32.whl", hash = "sha256:5726cf76c982532c1863fb64d8c6dd0e4c90b6ece9feb06c9f202417a31f7dd7", size = 94836 }, + { url = "https://files.pythonhosted.org/packages/3e/67/7b72b69d25b89c0b3cea583ee372c43aa24df15f0e0f8d3982c57804984b/charset_normalizer-3.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:b197e7094f232959f8f20541ead1d9862ac5ebea1d58e9849c1bf979255dfac9", size = 102187 }, + { url = "https://files.pythonhosted.org/packages/f3/89/68a4c86f1a0002810a27f12e9a7b22feb198c59b2f05231349fbce5c06f4/charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dd4eda173a9fcccb5f2e2bd2a9f423d180194b1bf17cf59e3269899235b2a114", size = 194617 }, + { url = "https://files.pythonhosted.org/packages/4f/cd/8947fe425e2ab0aa57aceb7807af13a0e4162cd21eee42ef5b053447edf5/charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9e3c4c9e1ed40ea53acf11e2a386383c3304212c965773704e4603d589343ed", size = 125310 }, + { url = "https://files.pythonhosted.org/packages/5b/f0/b5263e8668a4ee9becc2b451ed909e9c27058337fda5b8c49588183c267a/charset_normalizer-3.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:92a7e36b000bf022ef3dbb9c46bfe2d52c047d5e3f3343f43204263c5addc250", size = 119126 }, + { url = "https://files.pythonhosted.org/packages/ff/6e/e445afe4f7fda27a533f3234b627b3e515a1b9429bc981c9a5e2aa5d97b6/charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54b6a92d009cbe2fb11054ba694bc9e284dad30a26757b1e372a1fdddaf21920", size = 139342 }, + { url = "https://files.pythonhosted.org/packages/a1/b2/4af9993b532d93270538ad4926c8e37dc29f2111c36f9c629840c57cd9b3/charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ffd9493de4c922f2a38c2bf62b831dcec90ac673ed1ca182fe11b4d8e9f2a64", size = 149383 }, + { url = "https://files.pythonhosted.org/packages/fb/6f/4e78c3b97686b871db9be6f31d64e9264e889f8c9d7ab33c771f847f79b7/charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:35c404d74c2926d0287fbd63ed5d27eb911eb9e4a3bb2c6d294f3cfd4a9e0c23", size = 142214 }, + { url = "https://files.pythonhosted.org/packages/2b/c9/1c8fe3ce05d30c87eff498592c89015b19fade13df42850aafae09e94f35/charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4796efc4faf6b53a18e3d46343535caed491776a22af773f366534056c4e1fbc", size = 144104 }, + { url = 
"https://files.pythonhosted.org/packages/ee/68/efad5dcb306bf37db7db338338e7bb8ebd8cf38ee5bbd5ceaaaa46f257e6/charset_normalizer-3.4.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7fdd52961feb4c96507aa649550ec2a0d527c086d284749b2f582f2d40a2e0d", size = 146255 }, + { url = "https://files.pythonhosted.org/packages/0c/75/1ed813c3ffd200b1f3e71121c95da3f79e6d2a96120163443b3ad1057505/charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:92db3c28b5b2a273346bebb24857fda45601aef6ae1c011c0a997106581e8a88", size = 140251 }, + { url = "https://files.pythonhosted.org/packages/7d/0d/6f32255c1979653b448d3c709583557a4d24ff97ac4f3a5be156b2e6a210/charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ab973df98fc99ab39080bfb0eb3a925181454d7c3ac8a1e695fddfae696d9e90", size = 148474 }, + { url = "https://files.pythonhosted.org/packages/ac/a0/c1b5298de4670d997101fef95b97ac440e8c8d8b4efa5a4d1ef44af82f0d/charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4b67fdab07fdd3c10bb21edab3cbfe8cf5696f453afce75d815d9d7223fbe88b", size = 151849 }, + { url = "https://files.pythonhosted.org/packages/04/4f/b3961ba0c664989ba63e30595a3ed0875d6790ff26671e2aae2fdc28a399/charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:aa41e526a5d4a9dfcfbab0716c7e8a1b215abd3f3df5a45cf18a12721d31cb5d", size = 149781 }, + { url = "https://files.pythonhosted.org/packages/d8/90/6af4cd042066a4adad58ae25648a12c09c879efa4849c705719ba1b23d8c/charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ffc519621dce0c767e96b9c53f09c5d215578e10b02c285809f76509a3931482", size = 144970 }, + { url = "https://files.pythonhosted.org/packages/cc/67/e5e7e0cbfefc4ca79025238b43cdf8a2037854195b37d6417f3d0895c4c2/charset_normalizer-3.4.0-cp313-cp313-win32.whl", hash = "sha256:f19c1585933c82098c2a520f8ec1227f20e339e33aca8fa6f956f6691b784e67", size = 94973 }, + { url = "https://files.pythonhosted.org/packages/65/97/fc9bbc54ee13d33dc54a7fcf17b26368b18505500fc01e228c27b5222d80/charset_normalizer-3.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:707b82d19e65c9bd28b81dde95249b07bf9f5b90ebe1ef17d9b57473f8a64b7b", size = 102308 }, + { url = "https://files.pythonhosted.org/packages/bf/9b/08c0432272d77b04803958a4598a51e2a4b51c06640af8b8f0f908c18bf2/charset_normalizer-3.4.0-py3-none-any.whl", hash = "sha256:fe9f97feb71aa9896b81973a7bbada8c49501dc73e58a10fcef6663af95e5079", size = 49446 }, +] + +[[package]] +name = "click" +version = "8.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "platform_system == 'Windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/2e/d53fa4befbf2cfa713304affc7ca780ce4fc1fd8710527771b58311a3229/click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28", size = 97941 }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", 
size = 27697 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, +] + +[[package]] +name = "datacache" +version = "1.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "appdirs" }, + { name = "mock" }, + { name = "pandas" }, + { name = "progressbar33" }, + { name = "requests" }, + { name = "typechecks" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/27/42/7ed665dad2b9acaab52965eb98a3f9fc34ae77dfc7706afbe583b986a846/datacache-1.4.1.tar.gz", hash = "sha256:acfdbdc550c7e0971b42ec0575cec2245ec4ace1cdfb9e931852db157e5b6c67", size = 19063 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/63/bf/585fd1522561df26ae44caf40c827cb2f0ec3de5b2bb00d26117ad6b7663/datacache-1.4.1-py3-none-any.whl", hash = "sha256:a334747a2b849d7e4aa09aff3bc338ee10a97f02e5fe0237b6bb9254b4779635", size = 20258 }, +] + +[[package]] +name = "dill" +version = "0.3.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/70/43/86fe3f9e130c4137b0f1b50784dd70a5087b911fe07fa81e53e0c4c47fea/dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c", size = 187000 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/d1/e73b6ad76f0b1fb7f23c35c6d95dbc506a9c8804f43dda8cb5b0fa6331fd/dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a", size = 119418 }, +] + +[[package]] +name = "dnspython" +version = "2.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/4a/263763cb2ba3816dd94b08ad3a33d5fdae34ecb856678773cc40a3605829/dnspython-2.7.0.tar.gz", hash = "sha256:ce9c432eda0dc91cf618a5cedf1a4e142651196bbcd2c80e89ed5a907e5cfaf1", size = 345197 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632 }, +] + +[[package]] +name = "email-validator" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dnspython" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/48/ce/13508a1ec3f8bb981ae4ca79ea40384becc868bfae97fd1c942bb3a001b1/email_validator-2.2.0.tar.gz", hash = "sha256:cb690f344c617a714f22e66ae771445a1ceb46821152df8e165c5f9a364582b7", size = 48967 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/ee/bf0adb559ad3c786f12bcbc9296b3f5675f529199bef03e2df281fa1fadb/email_validator-2.2.0-py3-none-any.whl", hash = "sha256:561977c2d73ce3611850a06fa56b414621e0c8faa9d66f2611407d87465da631", size = 33521 }, +] + +[[package]] +name = "fastapi" +version = "0.115.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/22/fa/19e3c7c9b31ac291987c82e959f36f88840bea183fa3dc3bb654669f19c1/fastapi-0.115.2.tar.gz", hash = "sha256:3995739e0b09fa12f984bce8fa9ae197b35d433750d3d312422d846e283697ee", size = 299968 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/c9/14/bbe7776356ef01f830f8085ca3ac2aea59c73727b6ffaa757abeb7d2900b/fastapi-0.115.2-py3-none-any.whl", hash = "sha256:61704c71286579cc5a598763905928f24ee98bfcc07aabe84cfefb98812bbc86", size = 94650 }, +] + +[package.optional-dependencies] +standard = [ + { name = "email-validator" }, + { name = "fastapi-cli", extra = ["standard"] }, + { name = "httpx" }, + { name = "jinja2" }, + { name = "python-multipart" }, + { name = "uvicorn", extra = ["standard"] }, +] + +[[package]] +name = "fastapi-cli" +version = "0.0.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typer" }, + { name = "uvicorn", extra = ["standard"] }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c5/f8/1ad5ce32d029aeb9117e9a5a9b3e314a8477525d60c12a9b7730a3c186ec/fastapi_cli-0.0.5.tar.gz", hash = "sha256:d30e1239c6f46fcb95e606f02cdda59a1e2fa778a54b64686b3ff27f6211ff9f", size = 15571 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/24/ea/4b5011012ac925fe2f83b19d0e09cee9d324141ec7bf5e78bb2817f96513/fastapi_cli-0.0.5-py3-none-any.whl", hash = "sha256:e94d847524648c748a5350673546bbf9bcaeb086b33c24f2e82e021436866a46", size = 9489 }, +] + +[package.optional-dependencies] +standard = [ + { name = "uvicorn", extra = ["standard"] }, +] + +[[package]] +name = "flatbuffers" +version = "24.3.25" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a9/74/2df95ef84b214d2bee0886d572775a6f38793f5ca6d7630c3239c91104ac/flatbuffers-24.3.25.tar.gz", hash = "sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4", size = 22139 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/41/f0/7e988a019bc54b2dbd0ad4182ef2d53488bb02e58694cd79d61369e85900/flatbuffers-24.3.25-py2.py3-none-any.whl", hash = "sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812", size = 26784 }, +] + +[[package]] +name = "gast" +version = "0.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3c/14/c566f5ca00c115db7725263408ff952b8ae6d6a4e792ef9c84e77d9af7a1/gast-0.6.0.tar.gz", hash = "sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb", size = 27708 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/61/8001b38461d751cd1a0c3a6ae84346796a5758123f3ed97a1b121dfbf4f3/gast-0.6.0-py3-none-any.whl", hash = "sha256:52b182313f7330389f72b069ba00f174cfe2a06411099547288839c6cbafbd54", size = 21173 }, +] + +[[package]] +name = "google-pasta" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/35/4a/0bd53b36ff0323d10d5f24ebd67af2de10a1117f5cf4d7add90df92756f1/google-pasta-0.2.0.tar.gz", hash = "sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e", size = 40430 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/de/c648ef6835192e6e2cc03f40b19eeda4382c49b5bafb43d88b931c4c74ac/google_pasta-0.2.0-py3-none-any.whl", hash = "sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed", size = 57471 }, +] + +[[package]] +name = "grpcio" +version = "1.67.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/ae/3c47d71ab4abd4bd60a7e2806071fe0a4b6937b9eabe522291787087ea1f/grpcio-1.67.0.tar.gz", hash = "sha256:e090b2553e0da1c875449c8e75073dd4415dd71c9bde6a406240fdf4c0ee467c", size 
= 12569330 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b0/2d/b2a783f1d93735a259676de5558ef019ac3511e894b8e9d224edc0d7d034/grpcio-1.67.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:227316b5631260e0bef8a3ce04fa7db4cc81756fea1258b007950b6efc90c05d", size = 5086495 }, + { url = "https://files.pythonhosted.org/packages/7b/13/c1f537a88dad543ca0a7be4dfee80a21b3b02b7df27750997777355e5840/grpcio-1.67.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d90cfdafcf4b45a7a076e3e2a58e7bc3d59c698c4f6470b0bb13a4d869cf2273", size = 10979109 }, + { url = "https://files.pythonhosted.org/packages/b7/83/d7cb72f2202fe8d608d25c7e9d6d75184bf6ef658688c818821add102211/grpcio-1.67.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:77196216d5dd6f99af1c51e235af2dd339159f657280e65ce7e12c1a8feffd1d", size = 5586952 }, + { url = "https://files.pythonhosted.org/packages/e5/18/8df585d0158af9e2b46ee2388bdb21de0e7f5bf4a47a86a861ebdbf947b5/grpcio-1.67.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:15c05a26a0f7047f720da41dc49406b395c1470eef44ff7e2c506a47ac2c0591", size = 6212460 }, + { url = "https://files.pythonhosted.org/packages/47/46/027f8943113961784ce1eb69a28544d9a62ffb286332820ba634d979c91c/grpcio-1.67.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3840994689cc8cbb73d60485c594424ad8adb56c71a30d8948d6453083624b52", size = 5849002 }, + { url = "https://files.pythonhosted.org/packages/eb/26/fb19d5bc277e665382c835d7af1f8c1e3197576eed76327824d79e2a4bef/grpcio-1.67.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5a1e03c3102b6451028d5dc9f8591131d6ab3c8a0e023d94c28cb930ed4b5f81", size = 6568222 }, + { url = "https://files.pythonhosted.org/packages/e0/cc/387efa986f166c068d48331c699e6ee662a057371065f35d3ca1bc09d799/grpcio-1.67.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:682968427a63d898759474e3b3178d42546e878fdce034fd7474ef75143b64e3", size = 6148002 }, + { url = "https://files.pythonhosted.org/packages/24/57/529504e3e3e910f0537a0a36184cb7241d0d111109d6588096a9f8c139bf/grpcio-1.67.0-cp312-cp312-win32.whl", hash = "sha256:d01793653248f49cf47e5695e0a79805b1d9d4eacef85b310118ba1dfcd1b955", size = 3596220 }, + { url = "https://files.pythonhosted.org/packages/1d/1f/acf03ee901313446d52c3916d527d4981de9f6f3edc69267d05509dcfa7b/grpcio-1.67.0-cp312-cp312-win_amd64.whl", hash = "sha256:985b2686f786f3e20326c4367eebdaed3e7aa65848260ff0c6644f817042cb15", size = 4343545 }, + { url = "https://files.pythonhosted.org/packages/7a/e7/cc7feccb18ef0b5aa67ccb7859a091fa836c5d361a0109b9fca578e59e64/grpcio-1.67.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:8c9a35b8bc50db35ab8e3e02a4f2a35cfba46c8705c3911c34ce343bd777813a", size = 5087009 }, + { url = "https://files.pythonhosted.org/packages/bd/56/10175f4b1600b16e601680df053361924a9fcd9e1c0ad9b8bd1ba2b4c864/grpcio-1.67.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:42199e704095b62688998c2d84c89e59a26a7d5d32eed86d43dc90e7a3bd04aa", size = 10937553 }, + { url = "https://files.pythonhosted.org/packages/aa/85/115538b1aeb09d66c6e637608a56eddacd59eb71ab0161ad59172c01d436/grpcio-1.67.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:c4c425f440fb81f8d0237c07b9322fc0fb6ee2b29fbef5f62a322ff8fcce240d", size = 5586507 }, + { url = "https://files.pythonhosted.org/packages/0f/db/f402a455e287154683235183c2843c27fffe2fc03fa4c45b57dd90011401/grpcio-1.67.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:323741b6699cd2b04a71cb38f502db98f90532e8a40cb675393d248126a268af", 
size = 6211948 }, + { url = "https://files.pythonhosted.org/packages/92/e4/5957806105aad556f7df6a420b6c69044b6f707926392118772a8ba96de4/grpcio-1.67.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:662c8e105c5e5cee0317d500eb186ed7a93229586e431c1bf0c9236c2407352c", size = 5849392 }, + { url = "https://files.pythonhosted.org/packages/88/ab/c496a406f4682c56e933bef6b0ed22b9eaec84c6915f83d5cddd94126e16/grpcio-1.67.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:f6bd2ab135c64a4d1e9e44679a616c9bc944547357c830fafea5c3caa3de5153", size = 6571359 }, + { url = "https://files.pythonhosted.org/packages/9e/a8/96b3ef565791d7282c300c07c2a7080471311e7d5a239db15678aaac47eb/grpcio-1.67.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:2f55c1e0e2ae9bdd23b3c63459ee4c06d223b68aeb1961d83c48fb63dc29bc03", size = 6147905 }, + { url = "https://files.pythonhosted.org/packages/cd/b7/846cc563209ff5af88bc7dcb269948210674c2f743e7fd8e1a2ad9708e89/grpcio-1.67.0-cp313-cp313-win32.whl", hash = "sha256:fd6bc27861e460fe28e94226e3673d46e294ca4673d46b224428d197c5935e69", size = 3594603 }, + { url = "https://files.pythonhosted.org/packages/bd/74/49d27908b369b72fd3373ec0f16d7f58614fb7101cb38b266afeab846cca/grpcio-1.67.0-cp313-cp313-win_amd64.whl", hash = "sha256:cf51d28063338608cd8d3cd64677e922134837902b70ce00dad7f116e3998210", size = 4345468 }, +] + +[[package]] +name = "gtfparse" +version = "2.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "polars" }, + { name = "pyarrow" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/e9/f576f57742fe8fcb8086eb72c6de60e75e7adce7f6dc582d34b5143fb1ab/gtfparse-2.5.0.tar.gz", hash = "sha256:9fea54811cd87f597a110a49dc1b1b6a3325ffb7d1f36ecc62c32d14d3eb9493", size = 17228 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/4b/0cb91cedef2b9e93f340166e8709587e3c36366fff4964973ff0d38908ba/gtfparse-2.5.0-py3-none-any.whl", hash = "sha256:ccc9e9e77b7bdd90dda0e41da864714cb40b6b0c64ecc1d8a131e11497357140", size = 15450 }, +] + +[[package]] +name = "h11" +version = "0.14.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/38/3af3d3633a34a3316095b39c8e8fb4853a28a536e55d347bd8d8e9a14b03/h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d", size = 100418 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259 }, +] + +[[package]] +name = "h5py" +version = "3.12.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cc/0c/5c2b0a88158682aeafb10c1c2b735df5bc31f165bfe192f2ee9f2a23b5f1/h5py-3.12.1.tar.gz", hash = "sha256:326d70b53d31baa61f00b8aa5f95c2fcb9621a3ee8365d770c551a13dbbcbfdf", size = 411457 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/e1/ea9bfe18a3075cdc873f0588ff26ce394726047653557876d7101bf0c74e/h5py-3.12.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:06a903a4e4e9e3ebbc8b548959c3c2552ca2d70dac14fcfa650d9261c66939ed", size = 3372538 }, + { url = "https://files.pythonhosted.org/packages/0d/74/1009b663387c025e8fa5f3ee3cf3cd0d99b1ad5c72eeb70e75366b1ce878/h5py-3.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:7b3b8f3b48717e46c6a790e3128d39c61ab595ae0a7237f06dfad6a3b51d5351", size = 2868104 }, + { url = "https://files.pythonhosted.org/packages/af/52/c604adc06280c15a29037d4aa79a24fe54d8d0b51085e81ed24b2fa995f7/h5py-3.12.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:050a4f2c9126054515169c49cb900949814987f0c7ae74c341b0c9f9b5056834", size = 5194606 }, + { url = "https://files.pythonhosted.org/packages/fa/63/eeaacff417b393491beebabb8a3dc5342950409eb6d7b39d437289abdbae/h5py-3.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c4b41d1019322a5afc5082864dfd6359f8935ecd37c11ac0029be78c5d112c9", size = 5413256 }, + { url = "https://files.pythonhosted.org/packages/86/f7/bb465dcb92ca3521a15cbe1031f6d18234dbf1fb52a6796a00bfaa846ebf/h5py-3.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:e4d51919110a030913201422fb07987db4338eba5ec8c5a15d6fab8e03d443fc", size = 2993055 }, + { url = "https://files.pythonhosted.org/packages/23/1c/ecdd0efab52c24f2a9bf2324289828b860e8dd1e3c5ada3cf0889e14fdc1/h5py-3.12.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:513171e90ed92236fc2ca363ce7a2fc6f2827375efcbb0cc7fbdd7fe11fecafc", size = 3346239 }, + { url = "https://files.pythonhosted.org/packages/93/cd/5b6f574bf3e318bbe305bc93ba45181676550eb44ba35e006d2e98004eaa/h5py-3.12.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:59400f88343b79655a242068a9c900001a34b63e3afb040bd7cdf717e440f653", size = 2843416 }, + { url = "https://files.pythonhosted.org/packages/8a/4f/b74332f313bfbe94ba03fff784219b9db385e6139708e55b11490149f90a/h5py-3.12.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3e465aee0ec353949f0f46bf6c6f9790a2006af896cee7c178a8c3e5090aa32", size = 5154390 }, + { url = "https://files.pythonhosted.org/packages/1a/57/93ea9e10a6457ea8d3b867207deb29a527e966a08a84c57ffd954e32152a/h5py-3.12.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba51c0c5e029bb5420a343586ff79d56e7455d496d18a30309616fdbeed1068f", size = 5378244 }, + { url = "https://files.pythonhosted.org/packages/50/51/0bbf3663062b2eeee78aa51da71e065f8a0a6e3cb950cc7020b4444999e6/h5py-3.12.1-cp313-cp313-win_amd64.whl", hash = "sha256:52ab036c6c97055b85b2a242cb540ff9590bacfda0c03dd0cf0661b311f522f8", size = 2979760 }, +] + +[[package]] +name = "httpcore" +version = "1.0.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b6/44/ed0fa6a17845fb033bd885c03e842f08c1b9406c86a2e60ac1ae1b9206a6/httpcore-1.0.6.tar.gz", hash = "sha256:73f6dbd6eb8c21bbf7ef8efad555481853f5f6acdeaff1edb0694289269ee17f", size = 85180 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/06/89/b161908e2f51be56568184aeb4a880fd287178d176fd1c860d2217f41106/httpcore-1.0.6-py3-none-any.whl", hash = "sha256:27b59625743b85577a8c0e10e55b50b5368a4f2cfe8cc7bcfa9cf00829c2682f", size = 78011 }, +] + +[[package]] +name = "httptools" +version = "0.6.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/9a/ce5e1f7e131522e6d3426e8e7a490b3a01f39a6696602e1c4f33f9e94277/httptools-0.6.4.tar.gz", hash = "sha256:4e93eee4add6493b59a5c514da98c939b244fce4a0d8879cd3f466562f4b7d5c", size = 240639 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/0e/d0b71465c66b9185f90a091ab36389a7352985fe857e352801c39d6127c8/httptools-0.6.4-cp312-cp312-macosx_10_13_universal2.whl", hash = 
"sha256:df017d6c780287d5c80601dafa31f17bddb170232d85c066604d8558683711a2", size = 200683 }, + { url = "https://files.pythonhosted.org/packages/e2/b8/412a9bb28d0a8988de3296e01efa0bd62068b33856cdda47fe1b5e890954/httptools-0.6.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:85071a1e8c2d051b507161f6c3e26155b5c790e4e28d7f236422dbacc2a9cc44", size = 104337 }, + { url = "https://files.pythonhosted.org/packages/9b/01/6fb20be3196ffdc8eeec4e653bc2a275eca7f36634c86302242c4fbb2760/httptools-0.6.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69422b7f458c5af875922cdb5bd586cc1f1033295aa9ff63ee196a87519ac8e1", size = 508796 }, + { url = "https://files.pythonhosted.org/packages/f7/d8/b644c44acc1368938317d76ac991c9bba1166311880bcc0ac297cb9d6bd7/httptools-0.6.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16e603a3bff50db08cd578d54f07032ca1631450ceb972c2f834c2b860c28ea2", size = 510837 }, + { url = "https://files.pythonhosted.org/packages/52/d8/254d16a31d543073a0e57f1c329ca7378d8924e7e292eda72d0064987486/httptools-0.6.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ec4f178901fa1834d4a060320d2f3abc5c9e39766953d038f1458cb885f47e81", size = 485289 }, + { url = "https://files.pythonhosted.org/packages/5f/3c/4aee161b4b7a971660b8be71a92c24d6c64372c1ab3ae7f366b3680df20f/httptools-0.6.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f9eb89ecf8b290f2e293325c646a211ff1c2493222798bb80a530c5e7502494f", size = 489779 }, + { url = "https://files.pythonhosted.org/packages/12/b7/5cae71a8868e555f3f67a50ee7f673ce36eac970f029c0c5e9d584352961/httptools-0.6.4-cp312-cp312-win_amd64.whl", hash = "sha256:db78cb9ca56b59b016e64b6031eda5653be0589dba2b1b43453f6e8b405a0970", size = 88634 }, + { url = "https://files.pythonhosted.org/packages/94/a3/9fe9ad23fd35f7de6b91eeb60848986058bd8b5a5c1e256f5860a160cc3e/httptools-0.6.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ade273d7e767d5fae13fa637f4d53b6e961fb7fd93c7797562663f0171c26660", size = 197214 }, + { url = "https://files.pythonhosted.org/packages/ea/d9/82d5e68bab783b632023f2fa31db20bebb4e89dfc4d2293945fd68484ee4/httptools-0.6.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:856f4bc0478ae143bad54a4242fccb1f3f86a6e1be5548fecfd4102061b3a083", size = 102431 }, + { url = "https://files.pythonhosted.org/packages/96/c1/cb499655cbdbfb57b577734fde02f6fa0bbc3fe9fb4d87b742b512908dff/httptools-0.6.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:322d20ea9cdd1fa98bd6a74b77e2ec5b818abdc3d36695ab402a0de8ef2865a3", size = 473121 }, + { url = "https://files.pythonhosted.org/packages/af/71/ee32fd358f8a3bb199b03261f10921716990808a675d8160b5383487a317/httptools-0.6.4-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d87b29bd4486c0093fc64dea80231f7c7f7eb4dc70ae394d70a495ab8436071", size = 473805 }, + { url = "https://files.pythonhosted.org/packages/8a/0a/0d4df132bfca1507114198b766f1737d57580c9ad1cf93c1ff673e3387be/httptools-0.6.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:342dd6946aa6bda4b8f18c734576106b8a31f2fe31492881a9a160ec84ff4bd5", size = 448858 }, + { url = "https://files.pythonhosted.org/packages/1e/6a/787004fdef2cabea27bad1073bf6a33f2437b4dbd3b6fb4a9d71172b1c7c/httptools-0.6.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b36913ba52008249223042dca46e69967985fb4051951f94357ea681e1f5dc0", size = 452042 }, + { url = 
"https://files.pythonhosted.org/packages/4d/dc/7decab5c404d1d2cdc1bb330b1bf70e83d6af0396fd4fc76fc60c0d522bf/httptools-0.6.4-cp313-cp313-win_amd64.whl", hash = "sha256:28908df1b9bb8187393d5b5db91435ccc9c8e891657f9cbb42a2541b44c82fc8", size = 87682 }, +] + +[[package]] +name = "httpx" +version = "0.27.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "certifi" }, + { name = "httpcore" }, + { name = "idna" }, + { name = "sniffio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/82/08f8c936781f67d9e6b9eeb8a0c8b4e406136ea4c3d1f89a5db71d42e0e6/httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2", size = 144189 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/95/9377bcb415797e44274b51d46e3249eba641711cf3348050f76ee7b15ffc/httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0", size = 76395 }, +] + +[[package]] +name = "idna" +version = "3.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 }, +] + +[[package]] +name = "iniconfig" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, +] + +[[package]] +name = "isort" +version = "5.13.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/87/f9/c1eb8635a24e87ade2efce21e3ce8cd6b8630bb685ddc9cdaca1349b2eb5/isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109", size = 175303 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/b3/8def84f539e7d2289a02f0524b944b15d7c75dab7628bedf1c4f0992029c/isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6", size = 92310 }, +] + +[[package]] +name = "jinja2" +version = "3.1.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ed/55/39036716d19cab0747a5020fc7e907f362fbf48c984b14e62127f7e68e5d/jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369", size = 240245 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/80/3a54838c3fb461f6fec263ebf3a3a41771bd05190238de3486aae8540c36/jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d", size = 133271 }, +] + +[[package]] +name = "joblib" +version = "1.4.2" +source = { 
registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/64/33/60135848598c076ce4b231e1b1895170f45fbcaeaa2c9d5e38b04db70c35/joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e", size = 2116621 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 }, +] + +[[package]] +name = "keras" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "h5py" }, + { name = "ml-dtypes" }, + { name = "namex" }, + { name = "numpy" }, + { name = "optree" }, + { name = "packaging" }, + { name = "rich" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/47/22/448401abc7deaee90592c48c5dcc27ad93518b605655bef7ec5eb9544fe5/keras-3.6.0.tar.gz", hash = "sha256:405727525a3522ed8f9ec0b46e0667e4c65fcf714a067322c16a00d902ded41d", size = 890295 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/88/eef50051a772dcb4433d1f3e4c1d6576ba450fe83e89d028d7e8b85a2122/keras-3.6.0-py3-none-any.whl", hash = "sha256:49585e4577f6e86bd890d96dfbcb1890f5bab5967ef831c07fd63f9d86e4bfe9", size = 1191019 }, +] + +[[package]] +name = "lazy-object-proxy" +version = "1.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/f0/f02e2d150d581a294efded4020094a371bbab42423fe78625ac18854d89b/lazy-object-proxy-1.10.0.tar.gz", hash = "sha256:78247b6d45f43a52ef35c25b5581459e85117225408a4128a3daf8bf9648ac69", size = 43271 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/5d/768a7f2ccebb29604def61842fd54f6f5f75c79e366ee8748dda84de0b13/lazy_object_proxy-1.10.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e98c8af98d5707dcdecc9ab0863c0ea6e88545d42ca7c3feffb6b4d1e370c7ba", size = 27560 }, + { url = "https://files.pythonhosted.org/packages/b3/ce/f369815549dbfa4bebed541fa4e1561d69e4f268a1f6f77da886df182dab/lazy_object_proxy-1.10.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:952c81d415b9b80ea261d2372d2a4a2332a3890c2b83e0535f263ddfe43f0d43", size = 72403 }, + { url = "https://files.pythonhosted.org/packages/44/46/3771e0a4315044aa7b67da892b2fb1f59dfcf0eaff2c8967b2a0a85d5896/lazy_object_proxy-1.10.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80b39d3a151309efc8cc48675918891b865bdf742a8616a337cb0090791a0de9", size = 72401 }, + { url = "https://files.pythonhosted.org/packages/81/39/84ce4740718e1c700bd04d3457ac92b2e9ce76529911583e7a2bf4d96eb2/lazy_object_proxy-1.10.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e221060b701e2aa2ea991542900dd13907a5c90fa80e199dbf5a03359019e7a3", size = 75375 }, + { url = "https://files.pythonhosted.org/packages/86/3b/d6b65da2b864822324745c0a73fe7fd86c67ccea54173682c3081d7adea8/lazy_object_proxy-1.10.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:92f09ff65ecff3108e56526f9e2481b8116c0b9e1425325e13245abfd79bdb1b", size = 75466 }, + { url = "https://files.pythonhosted.org/packages/f5/33/467a093bf004a70022cb410c590d937134bba2faa17bf9dc42a48f49af35/lazy_object_proxy-1.10.0-cp312-cp312-win32.whl", hash = "sha256:3ad54b9ddbe20ae9f7c1b29e52f123120772b06dbb18ec6be9101369d63a4074", size = 25914 }, + { url = 
"https://files.pythonhosted.org/packages/77/ce/7956dc5ac2f8b62291b798c8363c81810e22a9effe469629d297d087e350/lazy_object_proxy-1.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:127a789c75151db6af398b8972178afe6bda7d6f68730c057fbbc2e96b08d282", size = 27525 }, + { url = "https://files.pythonhosted.org/packages/31/8b/94dc8d58704ab87b39faed6f2fc0090b9d90e2e2aa2bbec35c79f3d2a054/lazy_object_proxy-1.10.0-pp310.pp311.pp312.pp38.pp39-none-any.whl", hash = "sha256:80fa48bd89c8f2f456fc0765c11c23bf5af827febacd2f523ca5bc1893fcc09d", size = 16405 }, +] + +[[package]] +name = "libclang" +version = "18.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6e/5c/ca35e19a4f142adffa27e3d652196b7362fa612243e2b916845d801454fc/libclang-18.1.1.tar.gz", hash = "sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250", size = 39612 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/49/f5e3e7e1419872b69f6f5e82ba56e33955a74bd537d8a1f5f1eff2f3668a/libclang-18.1.1-1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a", size = 25836045 }, + { url = "https://files.pythonhosted.org/packages/e2/e5/fc61bbded91a8830ccce94c5294ecd6e88e496cc85f6704bf350c0634b70/libclang-18.1.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5", size = 26502641 }, + { url = "https://files.pythonhosted.org/packages/db/ed/1df62b44db2583375f6a8a5e2ca5432bbdc3edb477942b9b7c848c720055/libclang-18.1.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8", size = 26420207 }, + { url = "https://files.pythonhosted.org/packages/1d/fc/716c1e62e512ef1c160e7984a73a5fc7df45166f2ff3f254e71c58076f7c/libclang-18.1.1-py2.py3-none-manylinux2010_x86_64.whl", hash = "sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b", size = 24515943 }, + { url = "https://files.pythonhosted.org/packages/3c/3d/f0ac1150280d8d20d059608cf2d5ff61b7c3b7f7bcf9c0f425ab92df769a/libclang-18.1.1-py2.py3-none-manylinux2014_aarch64.whl", hash = "sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592", size = 23784972 }, + { url = "https://files.pythonhosted.org/packages/fe/2f/d920822c2b1ce9326a4c78c0c2b4aa3fde610c7ee9f631b600acb5376c26/libclang-18.1.1-py2.py3-none-manylinux2014_armv7l.whl", hash = "sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe", size = 20259606 }, + { url = "https://files.pythonhosted.org/packages/2d/c2/de1db8c6d413597076a4259cea409b83459b2db997c003578affdd32bf66/libclang-18.1.1-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f", size = 24921494 }, + { url = "https://files.pythonhosted.org/packages/0b/2d/3f480b1e1d31eb3d6de5e3ef641954e5c67430d5ac93b7fa7e07589576c7/libclang-18.1.1-py2.py3-none-win_amd64.whl", hash = "sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb", size = 26415083 }, + { url = "https://files.pythonhosted.org/packages/71/cf/e01dc4cc79779cd82d77888a88ae2fa424d93b445ad4f6c02bfc18335b70/libclang-18.1.1-py2.py3-none-win_arm64.whl", hash = "sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8", size = 22361112 }, +] + +[[package]] +name = "markdown" +version = "3.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/54/28/3af612670f82f4c056911fbbbb42760255801b3068c48de792d354ff4472/markdown-3.7.tar.gz", hash = "sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2", size = 357086 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/08/83871f3c50fc983b88547c196d11cf8c3340e37c32d2e9d6152abe2c61f7/Markdown-3.7-py3-none-any.whl", hash = "sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803", size = 106349 }, +] + +[[package]] +name = "markdown-it-py" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528 }, +] + +[[package]] +name = "markupsafe" +version = "3.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/97/5d42485e71dfc078108a86d6de8fa46db44a1a9295e89c5d6d4a06e23a62/markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0", size = 20537 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/09/d1f21434c97fc42f09d290cbb6350d44eb12f09cc62c9476effdb33a18aa/MarkupSafe-3.0.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9778bd8ab0a994ebf6f84c2b949e65736d5575320a17ae8984a77fab08db94cf", size = 14274 }, + { url = "https://files.pythonhosted.org/packages/6b/b0/18f76bba336fa5aecf79d45dcd6c806c280ec44538b3c13671d49099fdd0/MarkupSafe-3.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:846ade7b71e3536c4e56b386c2a47adf5741d2d8b94ec9dc3e92e5e1ee1e2225", size = 12348 }, + { url = "https://files.pythonhosted.org/packages/e0/25/dd5c0f6ac1311e9b40f4af06c78efde0f3b5cbf02502f8ef9501294c425b/MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c99d261bd2d5f6b59325c92c73df481e05e57f19837bdca8413b9eac4bd8028", size = 24149 }, + { url = "https://files.pythonhosted.org/packages/f3/f0/89e7aadfb3749d0f52234a0c8c7867877876e0a20b60e2188e9850794c17/MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17c96c14e19278594aa4841ec148115f9c7615a47382ecb6b82bd8fea3ab0c8", size = 23118 }, + { url = "https://files.pythonhosted.org/packages/d5/da/f2eeb64c723f5e3777bc081da884b414671982008c47dcc1873d81f625b6/MarkupSafe-3.0.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:88416bd1e65dcea10bc7569faacb2c20ce071dd1f87539ca2ab364bf6231393c", size = 22993 }, + { url = "https://files.pythonhosted.org/packages/da/0e/1f32af846df486dce7c227fe0f2398dc7e2e51d4a370508281f3c1c5cddc/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2181e67807fc2fa785d0592dc2d6206c019b9502410671cc905d132a92866557", size = 24178 }, + { url = "https://files.pythonhosted.org/packages/c4/f6/bb3ca0532de8086cbff5f06d137064c8410d10779c4c127e0e47d17c0b71/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:52305740fe773d09cffb16f8ed0427942901f00adedac82ec8b67752f58a1b22", size = 23319 }, + { url = 
"https://files.pythonhosted.org/packages/a2/82/8be4c96ffee03c5b4a034e60a31294daf481e12c7c43ab8e34a1453ee48b/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48", size = 23352 }, + { url = "https://files.pythonhosted.org/packages/51/ae/97827349d3fcffee7e184bdf7f41cd6b88d9919c80f0263ba7acd1bbcb18/MarkupSafe-3.0.2-cp312-cp312-win32.whl", hash = "sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30", size = 15097 }, + { url = "https://files.pythonhosted.org/packages/c1/80/a61f99dc3a936413c3ee4e1eecac96c0da5ed07ad56fd975f1a9da5bc630/MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87", size = 15601 }, + { url = "https://files.pythonhosted.org/packages/83/0e/67eb10a7ecc77a0c2bbe2b0235765b98d164d81600746914bebada795e97/MarkupSafe-3.0.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ba9527cdd4c926ed0760bc301f6728ef34d841f405abf9d4f959c478421e4efd", size = 14274 }, + { url = "https://files.pythonhosted.org/packages/2b/6d/9409f3684d3335375d04e5f05744dfe7e9f120062c9857df4ab490a1031a/MarkupSafe-3.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f8b3d067f2e40fe93e1ccdd6b2e1d16c43140e76f02fb1319a05cf2b79d99430", size = 12352 }, + { url = "https://files.pythonhosted.org/packages/d2/f5/6eadfcd3885ea85fe2a7c128315cc1bb7241e1987443d78c8fe712d03091/MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:569511d3b58c8791ab4c2e1285575265991e6d8f8700c7be0e88f86cb0672094", size = 24122 }, + { url = "https://files.pythonhosted.org/packages/0c/91/96cf928db8236f1bfab6ce15ad070dfdd02ed88261c2afafd4b43575e9e9/MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15ab75ef81add55874e7ab7055e9c397312385bd9ced94920f2802310c930396", size = 23085 }, + { url = "https://files.pythonhosted.org/packages/c2/cf/c9d56af24d56ea04daae7ac0940232d31d5a8354f2b457c6d856b2057d69/MarkupSafe-3.0.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3818cb119498c0678015754eba762e0d61e5b52d34c8b13d770f0719f7b1d79", size = 22978 }, + { url = "https://files.pythonhosted.org/packages/2a/9f/8619835cd6a711d6272d62abb78c033bda638fdc54c4e7f4272cf1c0962b/MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cdb82a876c47801bb54a690c5ae105a46b392ac6099881cdfb9f6e95e4014c6a", size = 24208 }, + { url = "https://files.pythonhosted.org/packages/f9/bf/176950a1792b2cd2102b8ffeb5133e1ed984547b75db47c25a67d3359f77/MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:cabc348d87e913db6ab4aa100f01b08f481097838bdddf7c7a84b7575b7309ca", size = 23357 }, + { url = "https://files.pythonhosted.org/packages/ce/4f/9a02c1d335caabe5c4efb90e1b6e8ee944aa245c1aaaab8e8a618987d816/MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:444dcda765c8a838eaae23112db52f1efaf750daddb2d9ca300bcae1039adc5c", size = 23344 }, + { url = "https://files.pythonhosted.org/packages/ee/55/c271b57db36f748f0e04a759ace9f8f759ccf22b4960c270c78a394f58be/MarkupSafe-3.0.2-cp313-cp313-win32.whl", hash = "sha256:bcf3e58998965654fdaff38e58584d8937aa3096ab5354d493c77d1fdd66d7a1", size = 15101 }, + { url = "https://files.pythonhosted.org/packages/29/88/07df22d2dd4df40aba9f3e402e6dc1b8ee86297dddbad4872bd5e7b0094f/MarkupSafe-3.0.2-cp313-cp313-win_amd64.whl", hash = 
"sha256:e6a2a455bd412959b57a172ce6328d2dd1f01cb2135efda2e4576e8a23fa3b0f", size = 15603 }, + { url = "https://files.pythonhosted.org/packages/62/6a/8b89d24db2d32d433dffcd6a8779159da109842434f1dd2f6e71f32f738c/MarkupSafe-3.0.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b5a6b3ada725cea8a5e634536b1b01c30bcdcd7f9c6fff4151548d5bf6b3a36c", size = 14510 }, + { url = "https://files.pythonhosted.org/packages/7a/06/a10f955f70a2e5a9bf78d11a161029d278eeacbd35ef806c3fd17b13060d/MarkupSafe-3.0.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a904af0a6162c73e3edcb969eeeb53a63ceeb5d8cf642fade7d39e7963a22ddb", size = 12486 }, + { url = "https://files.pythonhosted.org/packages/34/cf/65d4a571869a1a9078198ca28f39fba5fbb910f952f9dbc5220afff9f5e6/MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa4e5faecf353ed117801a068ebab7b7e09ffb6e1d5e412dc852e0da018126c", size = 25480 }, + { url = "https://files.pythonhosted.org/packages/0c/e3/90e9651924c430b885468b56b3d597cabf6d72be4b24a0acd1fa0e12af67/MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0ef13eaeee5b615fb07c9a7dadb38eac06a0608b41570d8ade51c56539e509d", size = 23914 }, + { url = "https://files.pythonhosted.org/packages/66/8c/6c7cf61f95d63bb866db39085150df1f2a5bd3335298f14a66b48e92659c/MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d16a81a06776313e817c951135cf7340a3e91e8c1ff2fac444cfd75fffa04afe", size = 23796 }, + { url = "https://files.pythonhosted.org/packages/bb/35/cbe9238ec3f47ac9a7c8b3df7a808e7cb50fe149dc7039f5f454b3fba218/MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6381026f158fdb7c72a168278597a5e3a5222e83ea18f543112b2662a9b699c5", size = 25473 }, + { url = "https://files.pythonhosted.org/packages/e6/32/7621a4382488aa283cc05e8984a9c219abad3bca087be9ec77e89939ded9/MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:3d79d162e7be8f996986c064d1c7c817f6df3a77fe3d6859f6f9e7be4b8c213a", size = 24114 }, + { url = "https://files.pythonhosted.org/packages/0d/80/0985960e4b89922cb5a0bac0ed39c5b96cbc1a536a99f30e8c220a996ed9/MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:131a3c7689c85f5ad20f9f6fb1b866f402c445b220c19fe4308c0b147ccd2ad9", size = 24098 }, + { url = "https://files.pythonhosted.org/packages/82/78/fedb03c7d5380df2427038ec8d973587e90561b2d90cd472ce9254cf348b/MarkupSafe-3.0.2-cp313-cp313t-win32.whl", hash = "sha256:ba8062ed2cf21c07a9e295d5b8a2a5ce678b913b45fdf68c32d95d6c1291e0b6", size = 15208 }, + { url = "https://files.pythonhosted.org/packages/4f/65/6079a46068dfceaeabb5dcad6d674f5f5c61a6fa5673746f42a9f4c233b3/MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f", size = 15739 }, +] + +[[package]] +name = "mccabe" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/ff/0ffefdcac38932a54d2b5eed4e0ba8a408f215002cd178ad1df0f2806ff8/mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325", size = 9658 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/1a/1f68f9ba0c207934b35b86a8ca3aad8395a3d6dd7921c0686e23853ff5a9/mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e", size = 7350 }, +] + +[[package]] +name = "mdurl" +version = 
"0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 }, +] + +[[package]] +name = "memoized-property" +version = "1.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/70/db/23f8b5d86c9385299586c2469b58087f658f58eaeb414be0fd64cfd054e1/memoized-property-1.0.3.tar.gz", hash = "sha256:4be4d0209944b9b9b678dae9d7e312249fe2e6fb8bdc9bdaa1da4de324f0fcf5", size = 5011 } + +[[package]] +name = "mhcflurry" +version = "2.1.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "appdirs" }, + { name = "mhcgnomes" }, + { name = "pandas" }, + { name = "pyyaml" }, + { name = "scikit-learn" }, + { name = "six" }, + { name = "tensorflow" }, + { name = "tf-keras" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/41/3a/284f8069f9c694c0cd17079137c346de2bcd8b27d2d9925a9fc6e7725916/mhcflurry-2.1.4.tar.gz", hash = "sha256:5f8a4e743bf6e5a17ad4e6fe79be7755b45a5005d048fbf3d8eb90762236c68e", size = 144468 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ae/f1/18b213ad5397fd298887e67757c1ce58fa79f8119f0796a70ad3977f789e/mhcflurry-2.1.4-py3-none-any.whl", hash = "sha256:9e8168c8d205e22cc3f4a06abb4d0e0f188a0b594dc4a3fc3fb16a3f2a7345e6", size = 140832 }, +] + +[[package]] +name = "mhcgnomes" +version = "1.8.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "pandas" }, + { name = "pyyaml" }, + { name = "serializable" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/41/7b11a2fdee588025619868866ee9121235c5bb56bfddb4773d7c176bc4bb/mhcgnomes-1.8.6.tar.gz", hash = "sha256:d32b886d9cd58ed0e45d4cb3da83a439b1b68b59790ae04985711e489aa5e264", size = 708992 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b0/3e/4fa67920f80300828bbdcc5fe97eeb33958d6825f60a9b4f57ef392e8bd4/mhcgnomes-1.8.6-py3-none-any.whl", hash = "sha256:f40cc7e0ba44dd8f1e733ba0525a8db62e016a0fbd1591a6fe2298ccee64dda0", size = 103723 }, +] + +[[package]] +name = "mhcnames" +version = "0.4.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/30/30/29cb727cacfcb6ab465df723b4b632a47b4989f526f5ade3b0592763d34e/mhcnames-0.4.8.tar.gz", hash = "sha256:0a18de129eaa4bf8ce802e5e2d856806639ab17b392688ea13dc20ec6d3cc8a2", size = 13787 } + +[[package]] +name = "mhctools" +version = "1.9.1" +source = { git = "https://github.com/hammerlab/mhctools.git#868ed09b4dfcab18aed563727d65bca3408476ea" } +dependencies = [ + { name = "mhcflurry" }, + { name = "mhcnames" }, + { name = "numpy" }, + { name = "pandas" }, + { name = "pyensembl" }, + { name = "sercol" }, + { name = "varcode" }, +] + +[[package]] +name = "ml-dtypes" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/fd/15/76f86faa0902836cc133939732f7611ace68cf54148487a99c539c272dc8/ml_dtypes-0.4.1.tar.gz", hash = "sha256:fad5f2de464fd09127e49b7fd1252b9006fb43d2edc1ff112d390c324af5ca7a", size = 692594 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/1a/99e924f12e4b62139fbac87419698c65f956d58de0dbfa7c028fa5b096aa/ml_dtypes-0.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:827d3ca2097085cf0355f8fdf092b888890bb1b1455f52801a2d7756f056f54b", size = 405077 }, + { url = "https://files.pythonhosted.org/packages/8f/8c/7b610bd500617854c8cc6ed7c8cfb9d48d6a5c21a1437a36a4b9bc8a3598/ml_dtypes-0.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:772426b08a6172a891274d581ce58ea2789cc8abc1c002a27223f314aaf894e7", size = 2181554 }, + { url = "https://files.pythonhosted.org/packages/c7/c6/f89620cecc0581dc1839e218c4315171312e46c62a62da6ace204bda91c0/ml_dtypes-0.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:126e7d679b8676d1a958f2651949fbfa182832c3cd08020d8facd94e4114f3e9", size = 2160488 }, + { url = "https://files.pythonhosted.org/packages/ae/11/a742d3c31b2cc8557a48efdde53427fd5f9caa2fa3c9c27d826e78a66f51/ml_dtypes-0.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:df0fb650d5c582a9e72bb5bd96cfebb2cdb889d89daff621c8fbc60295eba66c", size = 127462 }, +] + +[[package]] +name = "mock" +version = "5.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/66/ab/41d09a46985ead5839d8be987acda54b5bb93f713b3969cc0be4f81c455b/mock-5.1.0.tar.gz", hash = "sha256:5e96aad5ccda4718e0a229ed94b2024df75cc2d55575ba5762d31f5767b8767d", size = 80232 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6b/20/471f41173930550f279ccb65596a5ac19b9ac974a8d93679bcd3e0c31498/mock-5.1.0-py3-none-any.whl", hash = "sha256:18c694e5ae8a208cdb3d2c20a993ca1a7b0efa258c247a1e565150f477f83744", size = 30938 }, +] + +[[package]] +name = "namex" +version = "0.0.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9d/48/d275cdb6216c6bb4f9351675795a0b48974e138f16b1ffe0252c1f8faa28/namex-0.0.8.tar.gz", hash = "sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b", size = 6623 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/59/7854fbfb59f8ae35483ce93493708be5942ebb6328cd85b3a609df629736/namex-0.0.8-py3-none-any.whl", hash = "sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487", size = 5806 }, +] + +[[package]] +name = "numpy" +version = "1.26.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/65/6e/09db70a523a96d25e115e71cc56a6f9031e7b8cd166c1ac8438307c14058/numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010", size = 15786129 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/12/8f2020a8e8b8383ac0177dc9570aad031a3beb12e38847f7129bacd96228/numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218", size = 20335901 }, + { url = "https://files.pythonhosted.org/packages/75/5b/ca6c8bd14007e5ca171c7c03102d17b4f4e0ceb53957e8c44343a9546dcc/numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b", size = 13685868 }, + { url = 
"https://files.pythonhosted.org/packages/79/f8/97f10e6755e2a7d027ca783f63044d5b1bc1ae7acb12afe6a9b4286eac17/numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b", size = 13925109 }, + { url = "https://files.pythonhosted.org/packages/0f/50/de23fde84e45f5c4fda2488c759b69990fd4512387a8632860f3ac9cd225/numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed", size = 17950613 }, + { url = "https://files.pythonhosted.org/packages/4c/0c/9c603826b6465e82591e05ca230dfc13376da512b25ccd0894709b054ed0/numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a", size = 13572172 }, + { url = "https://files.pythonhosted.org/packages/76/8c/2ba3902e1a0fc1c74962ea9bb33a534bb05984ad7ff9515bf8d07527cadd/numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0", size = 17786643 }, + { url = "https://files.pythonhosted.org/packages/28/4a/46d9e65106879492374999e76eb85f87b15328e06bd1550668f79f7b18c6/numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110", size = 5677803 }, + { url = "https://files.pythonhosted.org/packages/16/2e/86f24451c2d530c88daf997cb8d6ac622c1d40d19f5a031ed68a4b73a374/numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818", size = 15517754 }, +] + +[[package]] +name = "opt-einsum" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/b9/2ac072041e899a52f20cf9510850ff58295003aa75525e58343591b0cbfb/opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac", size = 63004 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932 }, +] + +[[package]] +name = "optree" +version = "0.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9f/33/cc6673a6141cb1546d94c3b8b396d115a338022875c485d8b2219853851b/optree-0.13.0.tar.gz", hash = "sha256:1ea493cde8c60f7950ccbd682bd67e787bf67ed2251d6d3e9ad7471b72d37538", size = 153449 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/c8/7cd83c6ffefaf92b8cb90200a286131880395eefd6ea39ecf4200e282158/optree-0.13.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:abeb8acc83d168063b70168ccf8dfd55f5a7ce50f9af2ca025c41285781ecdd4", size = 564255 }, + { url = "https://files.pythonhosted.org/packages/30/f7/28ed168baa10a4d38b38a9b6af5be499a6d03c84bd29f4900c6932260cd7/optree-0.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4771266f05e99e94312add38d45bf97a4d98449aeab100f5c658c521152eb5e5", size = 305077 }, + { url = "https://files.pythonhosted.org/packages/02/d8/3b2a4cc1a699a1aa953c43d913262bdc5cf0794abe19dff857ba860fd91d/optree-0.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc95c1d0c7acd534184bf3ba243a454e0942e4a7c8b9edd32d939fc15e33d753", size = 332200 }, + { url = 
"https://files.pythonhosted.org/packages/f1/03/b7b2beb97d385d400598c2932ab5e4dc6aace1f4f178a7016a3191172107/optree-0.13.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e48491e042f956d4232ebc138e07074100878c0080e3ba10af4c2db1ba4df9f", size = 376866 }, + { url = "https://files.pythonhosted.org/packages/69/51/a1ec0e804282ca358b888984eeb542e012aa946a993cf5fa7e25e33a6844/optree-0.13.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8e001d9c902e98912503eca66c93d4b4b22f5071e4ab777f4db9e140f35288f4", size = 374265 }, + { url = "https://files.pythonhosted.org/packages/57/7a/9f56f9ee27f1fb4a5fbfa74f6b17f9e34741aabf7d97a9ff762dfbf7bf09/optree-0.13.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:87870346278f46a8c22866ff48716590be35b4aea16e1373e695fb6442c28c41", size = 345345 }, + { url = "https://files.pythonhosted.org/packages/8a/d2/aa2eb567867bfbd9fb418619321bd94db5df2591e527dd405a21540c337b/optree-0.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7797c54a35e9d89b4664ec7d542745b87b5ffa9c1201c1062fdcd488eb583390", size = 362962 }, + { url = "https://files.pythonhosted.org/packages/93/bd/362ccdd09b4df8d66bb394c5affc55992cb4a07a770bee6ef83da28a0e0a/optree-0.13.0-cp312-cp312-win32.whl", hash = "sha256:fc90a5373c92f4a9babb4c40fe148516f52160c0ba803bc9b2f936367f2f7437", size = 254773 }, + { url = "https://files.pythonhosted.org/packages/76/04/4a130228ecc0fc1d7e839fd310fa3b655927815ff3368a887b1138ebf9dc/optree-0.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:1bc65743e8edb29e902cab894d1c4665a8fd6f8d10f75db68a2cef6c7246fa5c", size = 283486 }, + { url = "https://files.pythonhosted.org/packages/78/89/26d89a49173bedabcbd7d4f7a65e1950e1e0b08c56e4d9d9a8c4154bf890/optree-0.13.0-cp312-cp312-win_arm64.whl", hash = "sha256:de2729e1e4ae47a07ac3c70ff977ed1ebe19e7b44d5089075c94f7a9a2dc6f4f", size = 283486 }, + { url = "https://files.pythonhosted.org/packages/11/dd/ef557f1b2ccaddcc905f87b778c9f681b1c8571801f18c70c798c4d6f1d1/optree-0.13.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dda6efabd0621f53eb46a3789ec89c6fd2c90dfb57aebfce3fcda6eab9ed6a7e", size = 572235 }, + { url = "https://files.pythonhosted.org/packages/4c/7f/d266b929fe45b1966b97940c2e476c49fdebb10bb8007e52151e29731d6d/optree-0.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5de8da9bbdd08b6200244ee818cd15d1da0f2b06ac926dba0e686260bac7fd40", size = 308048 }, + { url = "https://files.pythonhosted.org/packages/ac/86/3b17056b654e5b610ff0df2dbd51d61679f7dd73256fa9c924a980c83296/optree-0.13.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca1e4854134023ba687a7abf45ed3355f773ca7198b6895d88a89030446a9f2e", size = 335440 }, + { url = "https://files.pythonhosted.org/packages/93/67/76decf96ccbf257756284a3ac03a4d09bd73e0c527e0cbe0e1c080789cf5/optree-0.13.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1ac5343e921ce21f8f10f91158ad6404a1488c1cc22ddfa6b34cfb9d997cebd", size = 380357 }, + { url = "https://files.pythonhosted.org/packages/45/02/6986e5c549ea89652e43553d36c2412aefcede1ff0f97a953eb4b38ee203/optree-0.13.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e282212ddf3aafb10ca6ca223772e06ea3c31687c9cae192467b8e0a7dafbfc", size = 378690 }, + { url = "https://files.pythonhosted.org/packages/9c/af/bd901f029ec3cdebbad23e43794a96df33f2b805fb705ba4c6f042f91afa/optree-0.13.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:24fcd4cb659bcd9b675bc3401950de891b32a047c4787857fb870cd515fcc315", size = 349577 }, + { url = "https://files.pythonhosted.org/packages/5e/d1/3336fe90a29b5237b437def9fd2446a69ca0a234dd80d68111a687884bd6/optree-0.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d735a7d2d2e2eb9a88d932d35b335c10fae9038034f381b6d437dafed46497e", size = 367375 }, + { url = "https://files.pythonhosted.org/packages/7a/c7/073b6d014963e5621c4d515e6a86541787ef7920ab324e8db96ea755e71d/optree-0.13.0-cp313-cp313-win32.whl", hash = "sha256:ef01e79224f0ee6cf2ca642884f0bc04e446227b96dc576c312717eb33552d57", size = 257814 }, + { url = "https://files.pythonhosted.org/packages/bd/41/92cd15786619bd1d3c233661ecdfbe790a26610a0979122a44e1972b9c66/optree-0.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:d3f61fb669b36c1a714346b18c9c488ad33a58049b7b229785c241de18c005d7", size = 285900 }, + { url = "https://files.pythonhosted.org/packages/81/47/f4c77affa551a46e2019a06db95cb4319293630383e9d7e18fe1f47435c5/optree-0.13.0-cp313-cp313-win_arm64.whl", hash = "sha256:695b3f1aab50519230e3d8d86abaedaadf91af105b569cce3b8ebe0dc612b312", size = 285914 }, + { url = "https://files.pythonhosted.org/packages/ef/92/2a4a1ee91d074f3902813c31152c006a1b4f9acce2feafc4c9e96f020d74/optree-0.13.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:1318434b0740a2325c197e191e6dd53d9df0a8ac0338c67d58b476aad9d07829", size = 674076 }, + { url = "https://files.pythonhosted.org/packages/84/67/6585d6f5d143d89afcc85bd78f70d19fb9295bab37d7ae5d21ae5aa690ca/optree-0.13.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d58c6e8d4c4fa4e0c31bc4b876960ccba94eb5fcfb045f2b064ce55707034be9", size = 354957 }, + { url = "https://files.pythonhosted.org/packages/84/66/6e22fd91c3c0efae0c16f23d90a1e1f48813d1d377ebc829a134fa9b37ec/optree-0.13.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6a290ba771cc9004f9fc194d23ab11ee4aae71550ca874c3dc985af5b5f910b", size = 355144 }, + { url = "https://files.pythonhosted.org/packages/5a/db/529170be3bd962b1aa583bb9515f918dd45ccc48f6b47a46727b5640599e/optree-0.13.0-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c95488ecbab2916de094e68f2a2c55c9475b2e979c03d91a6cd3565f9e5ff2f9", size = 398205 }, + { url = "https://files.pythonhosted.org/packages/e6/a4/709302007b6845608c63c853cde863876922304761b239fdd8527a297fd7/optree-0.13.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8f76a65ff322b3d47af2a23f60409d6d8f184804da551c734e355834e69c0dfb", size = 398120 }, + { url = "https://files.pythonhosted.org/packages/4c/3f/f74ae5e59a9cc092c6847aecfd46776ea544211228b6fed6f0ae04a53df3/optree-0.13.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:58cc303f982fb0f23644b7f8e98b4f64b0d031365fcc2284da896e96493176d2", size = 368580 }, + { url = "https://files.pythonhosted.org/packages/b4/10/34437943c94e0c493c3c413c57af4091a245dcdbd57d84acc6ac0979931a/optree-0.13.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6866b6e4154303dc7c48c7ca3b867a8ce31d469334b67976dfc0513455aa1ca0", size = 384994 }, + { url = "https://files.pythonhosted.org/packages/56/36/84fed312a5aa1865dbe4292324926d153fd92a7d915740b14d900176e79f/optree-0.13.0-cp313-cp313t-win32.whl", hash = "sha256:f5ce67f81fe3d7ca5fed8fdaf93a762a63e1d125e20e425ca7200f9e54a3e3a6", size = 285895 }, + { url = 
"https://files.pythonhosted.org/packages/45/e6/a043f21d04674544207248c4d7fa1348a61be0fbdc64794325384ed53bc6/optree-0.13.0-cp313-cp313t-win_amd64.whl", hash = "sha256:0008cd39169c1fc10870528b2decfea8b79e61042c12d65a964f3b1cf41cc37d", size = 321826 }, + { url = "https://files.pythonhosted.org/packages/0c/ef/0f4ebc4061022ffca10637936f7d7b9a45c8db7593c1f06d32641a7bc16f/optree-0.13.0-cp313-cp313t-win_arm64.whl", hash = "sha256:539962675b547957c64b52b7f82178febb9c0f2d47438b810bbc23cfdcf84821", size = 321828 }, +] + +[[package]] +name = "packaging" +version = "24.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/51/65/50db4dda066951078f0a96cf12f4b9ada6e4b811516bf0262c0f4f7064d4/packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002", size = 148788 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/aa/cc0199a5f0ad350994d660967a8efb233fe0416e4639146c089643407ce6/packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124", size = 53985 }, +] + +[[package]] +name = "pandas" +version = "2.2.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "python-dateutil" }, + { name = "pytz" }, + { name = "tzdata" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9c/d6/9f8431bacc2e19dca897724cd097b1bb224a6ad5433784a44b587c7c13af/pandas-2.2.3.tar.gz", hash = "sha256:4f18ba62b61d7e192368b84517265a99b4d7ee8912f8708660fb4a366cc82667", size = 4399213 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/a3/fb2734118db0af37ea7433f57f722c0a56687e14b14690edff0cdb4b7e58/pandas-2.2.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b1d432e8d08679a40e2a6d8b2f9770a5c21793a6f9f47fdd52c5ce1948a5a8a9", size = 12529893 }, + { url = "https://files.pythonhosted.org/packages/e1/0c/ad295fd74bfac85358fd579e271cded3ac969de81f62dd0142c426b9da91/pandas-2.2.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a5a1595fe639f5988ba6a8e5bc9649af3baf26df3998a0abe56c02609392e0a4", size = 11363475 }, + { url = "https://files.pythonhosted.org/packages/c6/2a/4bba3f03f7d07207481fed47f5b35f556c7441acddc368ec43d6643c5777/pandas-2.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5de54125a92bb4d1c051c0659e6fcb75256bf799a732a87184e5ea503965bce3", size = 15188645 }, + { url = "https://files.pythonhosted.org/packages/38/f8/d8fddee9ed0d0c0f4a2132c1dfcf0e3e53265055da8df952a53e7eaf178c/pandas-2.2.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319", size = 12739445 }, + { url = "https://files.pythonhosted.org/packages/20/e8/45a05d9c39d2cea61ab175dbe6a2de1d05b679e8de2011da4ee190d7e748/pandas-2.2.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dfcb5ee8d4d50c06a51c2fffa6cff6272098ad6540aed1a76d15fb9318194d8", size = 16359235 }, + { url = "https://files.pythonhosted.org/packages/1d/99/617d07a6a5e429ff90c90da64d428516605a1ec7d7bea494235e1c3882de/pandas-2.2.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:062309c1b9ea12a50e8ce661145c6aab431b1e99530d3cd60640e255778bd43a", size = 14056756 }, + { url = "https://files.pythonhosted.org/packages/29/d4/1244ab8edf173a10fd601f7e13b9566c1b525c4f365d6bee918e68381889/pandas-2.2.3-cp312-cp312-win_amd64.whl", hash = "sha256:59ef3764d0fe818125a5097d2ae867ca3fa64df032331b7e0917cf5d7bf66b13", size = 11504248 }, + { url = 
"https://files.pythonhosted.org/packages/64/22/3b8f4e0ed70644e85cfdcd57454686b9057c6c38d2f74fe4b8bc2527214a/pandas-2.2.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f00d1345d84d8c86a63e476bb4955e46458b304b9575dcf71102b5c705320015", size = 12477643 }, + { url = "https://files.pythonhosted.org/packages/e4/93/b3f5d1838500e22c8d793625da672f3eec046b1a99257666c94446969282/pandas-2.2.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3508d914817e153ad359d7e069d752cdd736a247c322d932eb89e6bc84217f28", size = 11281573 }, + { url = "https://files.pythonhosted.org/packages/f5/94/6c79b07f0e5aab1dcfa35a75f4817f5c4f677931d4234afcd75f0e6a66ca/pandas-2.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22a9d949bfc9a502d320aa04e5d02feab689d61da4e7764b62c30b991c42c5f0", size = 15196085 }, + { url = "https://files.pythonhosted.org/packages/e8/31/aa8da88ca0eadbabd0a639788a6da13bb2ff6edbbb9f29aa786450a30a91/pandas-2.2.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3a255b2c19987fbbe62a9dfd6cff7ff2aa9ccab3fc75218fd4b7530f01efa24", size = 12711809 }, + { url = "https://files.pythonhosted.org/packages/ee/7c/c6dbdb0cb2a4344cacfb8de1c5808ca885b2e4dcfde8008266608f9372af/pandas-2.2.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:800250ecdadb6d9c78eae4990da62743b857b470883fa27f652db8bdde7f6659", size = 16356316 }, + { url = "https://files.pythonhosted.org/packages/57/b7/8b757e7d92023b832869fa8881a992696a0bfe2e26f72c9ae9f255988d42/pandas-2.2.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6374c452ff3ec675a8f46fd9ab25c4ad0ba590b71cf0656f8b6daa5202bca3fb", size = 14022055 }, + { url = "https://files.pythonhosted.org/packages/3b/bc/4b18e2b8c002572c5a441a64826252ce5da2aa738855747247a971988043/pandas-2.2.3-cp313-cp313-win_amd64.whl", hash = "sha256:61c5ad4043f791b61dd4752191d9f07f0ae412515d59ba8f005832a532f8736d", size = 11481175 }, + { url = "https://files.pythonhosted.org/packages/76/a3/a5d88146815e972d40d19247b2c162e88213ef51c7c25993942c39dbf41d/pandas-2.2.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:3b71f27954685ee685317063bf13c7709a7ba74fc996b84fc6821c59b0f06468", size = 12615650 }, + { url = "https://files.pythonhosted.org/packages/9c/8c/f0fd18f6140ddafc0c24122c8a964e48294acc579d47def376fef12bcb4a/pandas-2.2.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:38cf8125c40dae9d5acc10fa66af8ea6fdf760b2714ee482ca691fc66e6fcb18", size = 11290177 }, + { url = "https://files.pythonhosted.org/packages/ed/f9/e995754eab9c0f14c6777401f7eece0943840b7a9fc932221c19d1abee9f/pandas-2.2.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ba96630bc17c875161df3818780af30e43be9b166ce51c9a18c1feae342906c2", size = 14651526 }, + { url = "https://files.pythonhosted.org/packages/25/b0/98d6ae2e1abac4f35230aa756005e8654649d305df9a28b16b9ae4353bff/pandas-2.2.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1db71525a1538b30142094edb9adc10be3f3e176748cd7acc2240c2f2e5aa3a4", size = 11871013 }, + { url = "https://files.pythonhosted.org/packages/cc/57/0f72a10f9db6a4628744c8e8f0df4e6e21de01212c7c981d31e50ffc8328/pandas-2.2.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:15c0e1e02e93116177d29ff83e8b1619c93ddc9c49083f237d4312337a61165d", size = 15711620 }, + { url = "https://files.pythonhosted.org/packages/ab/5f/b38085618b950b79d2d9164a711c52b10aefc0ae6833b96f626b7021b2ed/pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = 
"sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a", size = 13098436 }, +] + +[[package]] +name = "platformdirs" +version = "4.3.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/13/fc/128cc9cb8f03208bdbf93d3aa862e16d376844a14f9a0ce5cf4507372de4/platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907", size = 21302 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/a6/bc1012356d8ece4d66dd75c4b9fc6c1f6650ddd5991e421177d9f8f671be/platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb", size = 18439 }, +] + +[[package]] +name = "pluggy" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, +] + +[[package]] +name = "polars" +version = "0.20.31" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/10/cb/447d8ba0d38df42bf247ade1fb7ad1ba61f09a95144ee8c4ab0314d38703/polars-0.20.31.tar.gz", hash = "sha256:00f62dec6bf43a4e2a5db58b99bf0e79699fe761c80ae665868eaea5168f3bbb", size = 3666354 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/74/ca08d8b5d067159541c4419f0cbd0b474bd44a89f97e79e0b4b3fd5b24b5/polars-0.20.31-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:86454ade5ed302bbf87f145cfcb1b14f7a5765a9440e448659e1f3dba6ac4e79", size = 27645039 }, + { url = "https://files.pythonhosted.org/packages/d0/71/984ed5f67c824c9b547665454ee438e0540a1ce2e8eca4d2021eeaf826aa/polars-0.20.31-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:67f2fe842262b7e1b9371edad21b760f6734d28b74c78dda88dff1bf031b9499", size = 24750031 }, + { url = "https://files.pythonhosted.org/packages/90/7d/7541e559d7fce232ba34340b0953cac9af2344853d675dc2de01a4d3abc7/polars-0.20.31-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24b82441f93409e0e8abd6f427b029db102f02b8de328cee9a680f84b84e3736", size = 28792870 }, + { url = "https://files.pythonhosted.org/packages/72/6a/6bf5da56542ae976140dd30be950149146c361eb8dd6471fdb6d50ae7581/polars-0.20.31-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:87f43bce4d41abf8c8c5658d881e4b8378e5c61010a696bfea8b4106b908e916", size = 26994077 }, + { url = "https://files.pythonhosted.org/packages/37/3e/d8f460c420254b094df5b2fa24e1d5571611540309eb66dad46405fb9b47/polars-0.20.31-cp38-abi3-win_amd64.whl", hash = "sha256:2d7567c9fd9d3b9aa93387ca9880d9e8f7acea3c0a0555c03d8c0c2f0715d43c", size = 28847550 }, +] + +[[package]] +name = "prediction-tools-fastapi" +version = "0.1.0" +source = { virtual = "." 
} +dependencies = [ + { name = "fastapi", extra = ["standard"] }, + { name = "httpx" }, + { name = "mhctools" }, + { name = "pandas" }, + { name = "pydantic-settings" }, + { name = "requests" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pylint-pydantic" }, + { name = "pytest" }, + { name = "python-dotenv" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "fastapi", extras = ["standard"], specifier = ">=0.115.0" }, + { name = "httpx", specifier = ">=0.27.2" }, + { name = "mhctools", git = "https://github.com/hammerlab/mhctools.git" }, + { name = "pandas", specifier = ">=2.2.3" }, + { name = "pydantic-settings", specifier = ">=2.5.2" }, + { name = "requests", specifier = ">=2.32.3" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pylint-pydantic", specifier = ">=0.3.2" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "python-dotenv", specifier = ">=1.0.1" }, + { name = "ruff", specifier = ">=0.6.8" }, +] + +[[package]] +name = "progressbar33" +version = "2.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/71/fc/7c8e01f41a6e671d7b11be470eeb3d15339c75ce5559935f3f55890eec6b/progressbar33-2.4.tar.gz", hash = "sha256:51fe0d9b3b4023db2f983eeccdfc8c9846b84db8443b9bee002c7f58f4376eff", size = 10184 } + +[[package]] +name = "protobuf" +version = "4.25.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/67/dd/48d5fdb68ec74d70fabcc252e434492e56f70944d9f17b6a15e3746d2295/protobuf-4.25.5.tar.gz", hash = "sha256:7f8249476b4a9473645db7f8ab42b02fe1488cbe5fb72fddd445e0665afd8584", size = 380315 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/35/1b3c5a5e6107859c4ca902f4fbb762e48599b78129a05d20684fef4a4d04/protobuf-4.25.5-cp310-abi3-win32.whl", hash = "sha256:5e61fd921603f58d2f5acb2806a929b4675f8874ff5f330b7d6f7e2e784bbcd8", size = 392457 }, + { url = "https://files.pythonhosted.org/packages/a7/ad/bf3f358e90b7e70bf7fb520702cb15307ef268262292d3bdb16ad8ebc815/protobuf-4.25.5-cp310-abi3-win_amd64.whl", hash = "sha256:4be0571adcbe712b282a330c6e89eae24281344429ae95c6d85e79e84780f5ea", size = 413449 }, + { url = "https://files.pythonhosted.org/packages/51/49/d110f0a43beb365758a252203c43eaaad169fe7749da918869a8c991f726/protobuf-4.25.5-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:b2fde3d805354df675ea4c7c6338c1aecd254dfc9925e88c6d31a2bcb97eb173", size = 394248 }, + { url = "https://files.pythonhosted.org/packages/c6/ab/0f384ca0bc6054b1a7b6009000ab75d28a5506e4459378b81280ae7fd358/protobuf-4.25.5-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:919ad92d9b0310070f8356c24b855c98df2b8bd207ebc1c0c6fcc9ab1e007f3d", size = 293717 }, + { url = "https://files.pythonhosted.org/packages/05/a6/094a2640be576d760baa34c902dcb8199d89bce9ed7dd7a6af74dcbbd62d/protobuf-4.25.5-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:fe14e16c22be926d3abfcb500e60cab068baf10b542b8c858fa27e098123e331", size = 294635 }, + { url = "https://files.pythonhosted.org/packages/33/90/f198a61df8381fb43ae0fe81b3d2718e8dcc51ae8502c7657ab9381fbc4f/protobuf-4.25.5-py3-none-any.whl", hash = "sha256:0aebecb809cae990f8129ada5ca273d9d670b76d9bfc9b1809f0a9c02b7dbf41", size = 156467 }, +] + +[[package]] +name = "pyarrow" +version = "14.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/d7/8b/d18b7eb6fb22e5ed6ffcbc073c85dae635778dbd1270a6cf5d750b031e84/pyarrow-14.0.2.tar.gz", hash = "sha256:36cef6ba12b499d864d1def3e990f97949e0b79400d08b7cf74504ffbd3eb025", size = 1063645 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/5b/d8ab6c20c43b598228710e4e4a6cba03a01f6faa3d08afff9ce76fd0fd47/pyarrow-14.0.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c87824a5ac52be210d32906c715f4ed7053d0180c1060ae3ff9b7e560f53f944", size = 26819585 }, + { url = "https://files.pythonhosted.org/packages/2d/29/bed2643d0dd5e9570405244a61f6db66c7f4704a6e9ce313f84fa5a3675a/pyarrow-14.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a25eb2421a58e861f6ca91f43339d215476f4fe159eca603c55950c14f378cc5", size = 23965222 }, + { url = "https://files.pythonhosted.org/packages/2a/34/da464632e59a8cdd083370d69e6c14eae30221acb284f671c6bc9273fadd/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c1da70d668af5620b8ba0a23f229030a4cd6c5f24a616a146f30d2386fec422", size = 35942036 }, + { url = "https://files.pythonhosted.org/packages/a8/ff/cbed4836d543b29f00d2355af67575c934999ff1d43e3f438ab0b1b394f1/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cc61593c8e66194c7cdfae594503e91b926a228fba40b5cf25cc593563bcd07", size = 38089266 }, + { url = "https://files.pythonhosted.org/packages/38/41/345011cb831d3dbb2dab762fc244c745a5df94b199223a99af52a5f7dff6/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:78ea56f62fb7c0ae8ecb9afdd7893e3a7dbeb0b04106f5c08dbb23f9c0157591", size = 35404468 }, + { url = "https://files.pythonhosted.org/packages/fd/af/2fc23ca2068ff02068d8dabf0fb85b6185df40ec825973470e613dbd8790/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:37c233ddbce0c67a76c0985612fef27c0c92aef9413cf5aa56952f359fcb7379", size = 38003134 }, + { url = "https://files.pythonhosted.org/packages/95/1f/9d912f66a87e3864f694e000977a6a70a644ea560289eac1d733983f215d/pyarrow-14.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:e4b123ad0f6add92de898214d404e488167b87b5dd86e9a434126bc2b7a5578d", size = 25043754 }, +] + +[[package]] +name = "pydantic" +version = "2.9.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a9/b7/d9e3f12af310e1120c21603644a1cd86f59060e040ec5c3a80b8f05fae30/pydantic-2.9.2.tar.gz", hash = "sha256:d155cef71265d1e9807ed1c32b4c8deec042a44a50a4188b25ac67ecd81a9c0f", size = 769917 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/df/e4/ba44652d562cbf0bf320e0f3810206149c8a4e99cdbf66da82e97ab53a15/pydantic-2.9.2-py3-none-any.whl", hash = "sha256:f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12", size = 434928 }, +] + +[[package]] +name = "pydantic-core" +version = "2.23.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e2/aa/6b6a9b9f8537b872f552ddd46dd3da230367754b6f707b8e1e963f515ea3/pydantic_core-2.23.4.tar.gz", hash = "sha256:2584f7cf844ac4d970fba483a717dbe10c1c1c96a969bf65d61ffe94df1b2863", size = 402156 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/74/7b/8e315f80666194b354966ec84b7d567da77ad927ed6323db4006cf915f3f/pydantic_core-2.23.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = 
"sha256:f3e0da4ebaef65158d4dfd7d3678aad692f7666877df0002b8a522cdf088f231", size = 1856459 }, + { url = "https://files.pythonhosted.org/packages/14/de/866bdce10ed808323d437612aca1ec9971b981e1c52e5e42ad9b8e17a6f6/pydantic_core-2.23.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f69a8e0b033b747bb3e36a44e7732f0c99f7edd5cea723d45bc0d6e95377ffee", size = 1770007 }, + { url = "https://files.pythonhosted.org/packages/dc/69/8edd5c3cd48bb833a3f7ef9b81d7666ccddd3c9a635225214e044b6e8281/pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723314c1d51722ab28bfcd5240d858512ffd3116449c557a1336cbe3919beb87", size = 1790245 }, + { url = "https://files.pythonhosted.org/packages/80/33/9c24334e3af796ce80d2274940aae38dd4e5676298b4398eff103a79e02d/pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bb2802e667b7051a1bebbfe93684841cc9351004e2badbd6411bf357ab8d5ac8", size = 1801260 }, + { url = "https://files.pythonhosted.org/packages/a5/6f/e9567fd90104b79b101ca9d120219644d3314962caa7948dd8b965e9f83e/pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d18ca8148bebe1b0a382a27a8ee60350091a6ddaf475fa05ef50dc35b5df6327", size = 1996872 }, + { url = "https://files.pythonhosted.org/packages/2d/ad/b5f0fe9e6cfee915dd144edbd10b6e9c9c9c9d7a56b69256d124b8ac682e/pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:33e3d65a85a2a4a0dc3b092b938a4062b1a05f3a9abde65ea93b233bca0e03f2", size = 2661617 }, + { url = "https://files.pythonhosted.org/packages/06/c8/7d4b708f8d05a5cbfda3243aad468052c6e99de7d0937c9146c24d9f12e9/pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:128585782e5bfa515c590ccee4b727fb76925dd04a98864182b22e89a4e6ed36", size = 2071831 }, + { url = "https://files.pythonhosted.org/packages/89/4d/3079d00c47f22c9a9a8220db088b309ad6e600a73d7a69473e3a8e5e3ea3/pydantic_core-2.23.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:68665f4c17edcceecc112dfed5dbe6f92261fb9d6054b47d01bf6371a6196126", size = 1917453 }, + { url = "https://files.pythonhosted.org/packages/e9/88/9df5b7ce880a4703fcc2d76c8c2d8eb9f861f79d0c56f4b8f5f2607ccec8/pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:20152074317d9bed6b7a95ade3b7d6054845d70584216160860425f4fbd5ee9e", size = 1968793 }, + { url = "https://files.pythonhosted.org/packages/e3/b9/41f7efe80f6ce2ed3ee3c2dcfe10ab7adc1172f778cc9659509a79518c43/pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9261d3ce84fa1d38ed649c3638feefeae23d32ba9182963e465d58d62203bd24", size = 2116872 }, + { url = "https://files.pythonhosted.org/packages/63/08/b59b7a92e03dd25554b0436554bf23e7c29abae7cce4b1c459cd92746811/pydantic_core-2.23.4-cp312-none-win32.whl", hash = "sha256:4ba762ed58e8d68657fc1281e9bb72e1c3e79cc5d464be146e260c541ec12d84", size = 1738535 }, + { url = "https://files.pythonhosted.org/packages/88/8d/479293e4d39ab409747926eec4329de5b7129beaedc3786eca070605d07f/pydantic_core-2.23.4-cp312-none-win_amd64.whl", hash = "sha256:97df63000f4fea395b2824da80e169731088656d1818a11b95f3b173747b6cd9", size = 1917992 }, + { url = "https://files.pythonhosted.org/packages/ad/ef/16ee2df472bf0e419b6bc68c05bf0145c49247a1095e85cee1463c6a44a1/pydantic_core-2.23.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:7530e201d10d7d14abce4fb54cfe5b94a0aefc87da539d0346a484ead376c3cc", size = 1856143 }, + { url = 
"https://files.pythonhosted.org/packages/da/fa/bc3dbb83605669a34a93308e297ab22be82dfb9dcf88c6cf4b4f264e0a42/pydantic_core-2.23.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:df933278128ea1cd77772673c73954e53a1c95a4fdf41eef97c2b779271bd0bd", size = 1770063 }, + { url = "https://files.pythonhosted.org/packages/4e/48/e813f3bbd257a712303ebdf55c8dc46f9589ec74b384c9f652597df3288d/pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cb3da3fd1b6a5d0279a01877713dbda118a2a4fc6f0d821a57da2e464793f05", size = 1790013 }, + { url = "https://files.pythonhosted.org/packages/b4/e0/56eda3a37929a1d297fcab1966db8c339023bcca0b64c5a84896db3fcc5c/pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42c6dcb030aefb668a2b7009c85b27f90e51e6a3b4d5c9bc4c57631292015b0d", size = 1801077 }, + { url = "https://files.pythonhosted.org/packages/04/be/5e49376769bfbf82486da6c5c1683b891809365c20d7c7e52792ce4c71f3/pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:696dd8d674d6ce621ab9d45b205df149399e4bb9aa34102c970b721554828510", size = 1996782 }, + { url = "https://files.pythonhosted.org/packages/bc/24/e3ee6c04f1d58cc15f37bcc62f32c7478ff55142b7b3e6d42ea374ea427c/pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2971bb5ffe72cc0f555c13e19b23c85b654dd2a8f7ab493c262071377bfce9f6", size = 2661375 }, + { url = "https://files.pythonhosted.org/packages/c1/f8/11a9006de4e89d016b8de74ebb1db727dc100608bb1e6bbe9d56a3cbbcce/pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8394d940e5d400d04cad4f75c0598665cbb81aecefaca82ca85bd28264af7f9b", size = 2071635 }, + { url = "https://files.pythonhosted.org/packages/7c/45/bdce5779b59f468bdf262a5bc9eecbae87f271c51aef628d8c073b4b4b4c/pydantic_core-2.23.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dff76e0602ca7d4cdaacc1ac4c005e0ce0dcfe095d5b5259163a80d3a10d327", size = 1916994 }, + { url = "https://files.pythonhosted.org/packages/d8/fa/c648308fe711ee1f88192cad6026ab4f925396d1293e8356de7e55be89b5/pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7d32706badfe136888bdea71c0def994644e09fff0bfe47441deaed8e96fdbc6", size = 1968877 }, + { url = "https://files.pythonhosted.org/packages/16/16/b805c74b35607d24d37103007f899abc4880923b04929547ae68d478b7f4/pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ed541d70698978a20eb63d8c5d72f2cc6d7079d9d90f6b50bad07826f1320f5f", size = 2116814 }, + { url = "https://files.pythonhosted.org/packages/d1/58/5305e723d9fcdf1c5a655e6a4cc2a07128bf644ff4b1d98daf7a9dbf57da/pydantic_core-2.23.4-cp313-none-win32.whl", hash = "sha256:3d5639516376dce1940ea36edf408c554475369f5da2abd45d44621cb616f769", size = 1738360 }, + { url = "https://files.pythonhosted.org/packages/a5/ae/e14b0ff8b3f48e02394d8acd911376b7b66e164535687ef7dc24ea03072f/pydantic_core-2.23.4-cp313-none-win_amd64.whl", hash = "sha256:5a1504ad17ba4210df3a045132a7baeeba5a200e930f57512ee02909fc5c4cb5", size = 1919411 }, +] + +[[package]] +name = "pydantic-settings" +version = "2.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6c/66/5f1a9da10675bfb3b9da52f5b689c77e0a5612263fcce510cfac3e99a168/pydantic_settings-2.6.0.tar.gz", hash = 
"sha256:44a1804abffac9e6a30372bb45f6cafab945ef5af25e66b1c634c01dd39e0188", size = 75232 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/19/26bb6bdb9fdad5f0dfce538780814084fb667b4bc37fcb28459c14b8d3b5/pydantic_settings-2.6.0-py3-none-any.whl", hash = "sha256:4a819166f119b74d7f8c765196b165f95cc7487ce58ea27dec8a5a26be0970e0", size = 28578 }, +] + +[[package]] +name = "pyensembl" +version = "2.3.13" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "datacache" }, + { name = "gtfparse" }, + { name = "memoized-property" }, + { name = "pylint" }, + { name = "serializable" }, + { name = "tinytimer" }, + { name = "typechecks" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/94/253fa59f1f9e7fe3b7d34818fd4fe08f857da241234bf18c6ee1b9496afe/pyensembl-2.3.13.tar.gz", hash = "sha256:c70ce760f68fe2a6be871db44e53ce1d4d1227f2ce0578c6b291d5a89f5d1832", size = 60784 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/4f/d8ea7378ffd03a7e71f6e750cb5fa7e8721a0e30f04af8ae60e704d59e96/pyensembl-2.3.13-py3-none-any.whl", hash = "sha256:46989e4eb3c3436ac2a8b02f17e999439c04ca2db0926ce6d5535a9a0a00bfce", size = 55952 }, +] + +[[package]] +name = "pygments" +version = "2.18.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/62/8336eff65bcbc8e4cb5d05b55faf041285951b6e80f33e2bff2024788f31/pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199", size = 4891905 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f7/3f/01c8b82017c199075f8f788d0d906b9ffbbc5a47dc9918a945e13d5a2bda/pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a", size = 1205513 }, +] + +[[package]] +name = "pylint" +version = "2.17.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "astroid" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "dill" }, + { name = "isort" }, + { name = "mccabe" }, + { name = "platformdirs" }, + { name = "tomlkit" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/e9/21f9ce3e4b81eef011be070a29f8a5c193e2488ed8713a898baa4e8b3362/pylint-2.17.7.tar.gz", hash = "sha256:f4fcac7ae74cfe36bc8451e931d8438e4a476c20314b1101c458ad0f05191fad", size = 434994 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b4/49/cea450a83079445a84f16050e571a7c383d3f474b13c5caedfebd4e35def/pylint-2.17.7-py3-none-any.whl", hash = "sha256:27a8d4c7ddc8c2f8c18aa0050148f89ffc09838142193fdbe98f172781a3ff87", size = 537178 }, +] + +[[package]] +name = "pylint-plugin-utils" +version = "0.8.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pylint" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4b/d2/3b9728910bc69232ec38d8fb7053c03c887bfe7e6e170649b683dd351750/pylint_plugin_utils-0.8.2.tar.gz", hash = "sha256:d3cebf68a38ba3fba23a873809155562571386d4c1b03e5b4c4cc26c3eee93e4", size = 10674 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/ee/49d11aee31061bcc1d2726bd8334a2883ddcdbde7d7744ed6b3bd11704ed/pylint_plugin_utils-0.8.2-py3-none-any.whl", hash = "sha256:ae11664737aa2effbf26f973a9e0b6779ab7106ec0adc5fe104b0907ca04e507", size = 11171 }, +] + +[[package]] +name = "pylint-pydantic" +version = "0.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "pylint" }, + { name = 
"pylint-plugin-utils" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/80/34b429c6534be99ef3d6d20bd794b26fda0682d38e2d57f85df258beaac2/pylint_pydantic-0.3.2-py3-none-any.whl", hash = "sha256:e5cec02370aa68ac8eff138e5d573b0ac049bab864e9a6c3a9057cf043440aa1", size = 15951 }, +] + +[[package]] +name = "pytest" +version = "8.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8b/6c/62bbd536103af674e227c41a8f3dcd022d591f6eed5facb5a0f31ee33bbc/pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181", size = 1442487 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 }, +] + +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892 }, +] + +[[package]] +name = "python-dotenv" +version = "1.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bc/57/e84d88dfe0aec03b7a2d4327012c1627ab5f03652216c63d49846d7a6c58/python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", size = 39115 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 }, +] + +[[package]] +name = "python-multipart" +version = "0.0.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/16/6e/7ecfe1366b9270f7f475c76fcfa28812493a6a1abd489b2433851a444f4f/python_multipart-0.0.12.tar.gz", hash = "sha256:045e1f98d719c1ce085ed7f7e1ef9d8ccc8c02ba02b5566d5f7521410ced58cb", size = 35713 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/0b/c316262244abea7481f95f1e91d7575f3dfcf6455d56d1ffe9839c582eb1/python_multipart-0.0.12-py3-none-any.whl", hash = "sha256:43dcf96cf65888a9cd3423544dd0d75ac10f7aa0c3c28a175bbcd00c9ce1aebf", size = 23246 }, +] + +[[package]] +name = "pytz" +version = "2024.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3a/31/3c70bf7603cc2dca0f19bdc53b4537a797747a58875b552c8c413d963a3f/pytz-2024.2.tar.gz", hash = "sha256:2aa355083c50a0f93fa581709deac0c9ad65cca8a9e9beac660adcbd493c798a", size = 319692 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/c3/005fcca25ce078d2cc29fd559379817424e94885510568bc1bc53d7d5846/pytz-2024.2-py2.py3-none-any.whl", 
hash = "sha256:31c7c1817eb7fae7ca4b8c7ee50c72f93aa2dd863de768e1ef4245d426aa0725", size = 508002 }, +] + +[[package]] +name = "pyvcf3" +version = "1.0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/28/ee/5d0b0bd20762f1126b34cb15dce973fe4dbd1baf17d7ee621b18e336c0e5/PyVCF3-1.0.3.tar.gz", hash = "sha256:4b16d71c8b97010487e2c939fb4d5707b7bbfa4e2b313df9dba3e372c5ba031d", size = 977597 } + +[[package]] +name = "pyyaml" +version = "6.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/54/ed/79a089b6be93607fa5cdaedf301d7dfb23af5f25c398d5ead2525b063e17/pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e", size = 130631 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/86/0c/c581167fc46d6d6d7ddcfb8c843a4de25bdd27e4466938109ca68492292c/PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab", size = 183873 }, + { url = "https://files.pythonhosted.org/packages/a8/0c/38374f5bb272c051e2a69281d71cba6fdb983413e6758b84482905e29a5d/PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725", size = 173302 }, + { url = "https://files.pythonhosted.org/packages/c3/93/9916574aa8c00aa06bbac729972eb1071d002b8e158bd0e83a3b9a20a1f7/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5", size = 739154 }, + { url = "https://files.pythonhosted.org/packages/95/0f/b8938f1cbd09739c6da569d172531567dbcc9789e0029aa070856f123984/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425", size = 766223 }, + { url = "https://files.pythonhosted.org/packages/b9/2b/614b4752f2e127db5cc206abc23a8c19678e92b23c3db30fc86ab731d3bd/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476", size = 767542 }, + { url = "https://files.pythonhosted.org/packages/d4/00/dd137d5bcc7efea1836d6264f049359861cf548469d18da90cd8216cf05f/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48", size = 731164 }, + { url = "https://files.pythonhosted.org/packages/c9/1f/4f998c900485e5c0ef43838363ba4a9723ac0ad73a9dc42068b12aaba4e4/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b", size = 756611 }, + { url = "https://files.pythonhosted.org/packages/df/d1/f5a275fdb252768b7a11ec63585bc38d0e87c9e05668a139fea92b80634c/PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4", size = 140591 }, + { url = "https://files.pythonhosted.org/packages/0c/e8/4f648c598b17c3d06e8753d7d13d57542b30d56e6c2dedf9c331ae56312e/PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", size = 156338 }, + { url = "https://files.pythonhosted.org/packages/ef/e3/3af305b830494fa85d95f6d95ef7fa73f2ee1cc8ef5b495c7c3269fb835f/PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = 
"sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba", size = 181309 }, + { url = "https://files.pythonhosted.org/packages/45/9f/3b1c20a0b7a3200524eb0076cc027a970d320bd3a6592873c85c92a08731/PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1", size = 171679 }, + { url = "https://files.pythonhosted.org/packages/7c/9a/337322f27005c33bcb656c655fa78325b730324c78620e8328ae28b64d0c/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133", size = 733428 }, + { url = "https://files.pythonhosted.org/packages/a3/69/864fbe19e6c18ea3cc196cbe5d392175b4cf3d5d0ac1403ec3f2d237ebb5/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484", size = 763361 }, + { url = "https://files.pythonhosted.org/packages/04/24/b7721e4845c2f162d26f50521b825fb061bc0a5afcf9a386840f23ea19fa/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5", size = 759523 }, + { url = "https://files.pythonhosted.org/packages/2b/b2/e3234f59ba06559c6ff63c4e10baea10e5e7df868092bf9ab40e5b9c56b6/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc", size = 726660 }, + { url = "https://files.pythonhosted.org/packages/fe/0f/25911a9f080464c59fab9027482f822b86bf0608957a5fcc6eaac85aa515/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652", size = 751597 }, + { url = "https://files.pythonhosted.org/packages/14/0d/e2c3b43bbce3cf6bd97c840b46088a3031085179e596d4929729d8d68270/PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183", size = 140527 }, + { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 }, +] + +[[package]] +name = "requests" +version = "2.32.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/63/70/2bf7780ad2d390a8d301ad0b550f1581eadbd9a20f896afe06353c2a2913/requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", size = 131218 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 }, +] + +[[package]] +name = "rich" +version = "13.9.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/aa/9e/1784d15b057b0075e5136445aaea92d23955aad2c93eaede673718a40d95/rich-13.9.2.tar.gz", hash = "sha256:51a2c62057461aaf7152b4d611168f93a9fc73068f8ded2790f29fe2b5366d0c", size = 222843 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/67/91/5474b84e505a6ccc295b2d322d90ff6aa0746745717839ee0c5fb4fdcceb/rich-13.9.2-py3-none-any.whl", hash = "sha256:8c82a3d3f8dcfe9e734771313e606b39d8247bb6b826e196f4914b333b743cf1", size = 242117 }, +] + +[[package]] +name = "ruff" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/c7/f3367d1da5d568192968c5c9e7f3d51fb317b9ac04828493b23d8fce8ce6/ruff-0.7.0.tar.gz", hash = "sha256:47a86360cf62d9cd53ebfb0b5eb0e882193fc191c6d717e8bef4462bc3b9ea2b", size = 3146645 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/59/a0275a0913f3539498d116046dd679cd657fe3b7caf5afe1733319414932/ruff-0.7.0-py3-none-linux_armv6l.whl", hash = "sha256:0cdf20c2b6ff98e37df47b2b0bd3a34aaa155f59a11182c1303cce79be715628", size = 10434007 }, + { url = "https://files.pythonhosted.org/packages/cd/94/da0ba5f956d04c90dd899209904210600009dcda039ce840d83eb4298c7d/ruff-0.7.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:496494d350c7fdeb36ca4ef1c9f21d80d182423718782222c29b3e72b3512737", size = 10048066 }, + { url = "https://files.pythonhosted.org/packages/57/1d/e5cc149ecc46e4f203403a79ccd170fad52d316f98b87d0f63b1945567db/ruff-0.7.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:214b88498684e20b6b2b8852c01d50f0651f3cc6118dfa113b4def9f14faaf06", size = 9711389 }, + { url = "https://files.pythonhosted.org/packages/05/67/fb7ea2c869c539725a16c5bc294e9aa34f8b1b6fe702f1d173a5da517c2b/ruff-0.7.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:630fce3fefe9844e91ea5bbf7ceadab4f9981f42b704fae011bb8efcaf5d84be", size = 10755174 }, + { url = "https://files.pythonhosted.org/packages/5f/f0/13703bc50536a0613ea3dce991116e5f0917a1f05528c6ab738b33c08d3f/ruff-0.7.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:211d877674e9373d4bb0f1c80f97a0201c61bcd1e9d045b6e9726adc42c156aa", size = 10196040 }, + { url = "https://files.pythonhosted.org/packages/99/c1/77b04ab20324ab03d333522ee55fb0f1c38e3ca0d326b4905f82ce6b6c70/ruff-0.7.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:194d6c46c98c73949a106425ed40a576f52291c12bc21399eb8f13a0f7073495", size = 11033684 }, + { url = "https://files.pythonhosted.org/packages/f2/97/f463334dc4efeea3551cd109163df15561c18a1c3ec13d51643740fd36ba/ruff-0.7.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:82c2579b82b9973a110fab281860403b397c08c403de92de19568f32f7178598", size = 11803700 }, + { url = "https://files.pythonhosted.org/packages/b4/f8/a31d40c4bb92933d376a53e7c5d0245d9b27841357e4820e96d38f54b480/ruff-0.7.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9af971fe85dcd5eaed8f585ddbc6bdbe8c217fb8fcf510ea6bca5bdfff56040e", size = 11347848 }, + { url = "https://files.pythonhosted.org/packages/83/62/0c133b35ddaf91c65c30a56718b80bdef36bfffc35684d29e3a4878e0ea3/ruff-0.7.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b641c7f16939b7d24b7bfc0be4102c56562a18281f84f635604e8a6989948914", size = 12480632 }, + { url = "https://files.pythonhosted.org/packages/46/96/464058dd1d980014fb5aa0a1254e78799efb3096fc7a4823cd66a1621276/ruff-0.7.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d71672336e46b34e0c90a790afeac8a31954fd42872c1f6adaea1dff76fd44f9", size = 10941919 }, + { url = 
"https://files.pythonhosted.org/packages/a0/f7/bda37ec77986a435dde44e1f59374aebf4282a5fa9cf17735315b847141f/ruff-0.7.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ab7d98c7eed355166f367597e513a6c82408df4181a937628dbec79abb2a1fe4", size = 10745519 }, + { url = "https://files.pythonhosted.org/packages/c2/33/5f77fc317027c057b61a848020a47442a1cbf12e592df0e41e21f4d0f3bd/ruff-0.7.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:1eb54986f770f49edb14f71d33312d79e00e629a57387382200b1ef12d6a4ef9", size = 10284872 }, + { url = "https://files.pythonhosted.org/packages/ff/50/98aec292bc9537f640b8d031c55f3414bf15b6ed13b3e943fed75ac927b9/ruff-0.7.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:dc452ba6f2bb9cf8726a84aa877061a2462afe9ae0ea1d411c53d226661c601d", size = 10600334 }, + { url = "https://files.pythonhosted.org/packages/f2/85/12607ae3201423a179b8cfadc7cb1e57d02cd0135e45bd0445acb4cef327/ruff-0.7.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4b406c2dce5be9bad59f2de26139a86017a517e6bcd2688da515481c05a2cb11", size = 11017333 }, + { url = "https://files.pythonhosted.org/packages/d4/7f/3b85a56879e705d5f46ec14daf8a439fca05c3081720fe3dc3209100922d/ruff-0.7.0-py3-none-win32.whl", hash = "sha256:f6c968509f767776f524a8430426539587d5ec5c662f6addb6aa25bc2e8195ec", size = 8570962 }, + { url = "https://files.pythonhosted.org/packages/39/9f/c5ee2b40d377354dabcc23cff47eb299de4b4d06d345068f8f8cc1eadac8/ruff-0.7.0-py3-none-win_amd64.whl", hash = "sha256:ff4aabfbaaba880e85d394603b9e75d32b0693152e16fa659a3064a85df7fce2", size = 9365544 }, + { url = "https://files.pythonhosted.org/packages/89/8b/ee1509f60148cecba644aa718f6633216784302458340311898aaf0b1bed/ruff-0.7.0-py3-none-win_arm64.whl", hash = "sha256:10842f69c245e78d6adec7e1db0a7d9ddc2fff0621d730e61657b64fa36f207e", size = 8695763 }, +] + +[[package]] +name = "scikit-learn" +version = "1.5.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "joblib" }, + { name = "numpy" }, + { name = "scipy" }, + { name = "threadpoolctl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/37/59/44985a2bdc95c74e34fef3d10cb5d93ce13b0e2a7baefffe1b53853b502d/scikit_learn-1.5.2.tar.gz", hash = "sha256:b4237ed7b3fdd0a4882792e68ef2545d5baa50aca3bb45aa7df468138ad8f94d", size = 7001680 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/db/b485c1ac54ff3bd9e7e6b39d3cc6609c4c76a65f52ab0a7b22b6c3ab0e9d/scikit_learn-1.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f932a02c3f4956dfb981391ab24bda1dbd90fe3d628e4b42caef3e041c67707a", size = 12110344 }, + { url = "https://files.pythonhosted.org/packages/54/1a/7deb52fa23aebb855431ad659b3c6a2e1709ece582cb3a63d66905e735fe/scikit_learn-1.5.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3b923d119d65b7bd555c73be5423bf06c0105678ce7e1f558cb4b40b0a5502b1", size = 11033502 }, + { url = "https://files.pythonhosted.org/packages/a1/32/4a7a205b14c11225609b75b28402c196e4396ac754dab6a81971b811781c/scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd", size = 12085794 }, + { url = "https://files.pythonhosted.org/packages/c6/29/044048c5e911373827c0e1d3051321b9183b2a4f8d4e2f11c08fcff83f13/scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6", size = 12945797 }, + { url = 
"https://files.pythonhosted.org/packages/aa/ce/c0b912f2f31aeb1b756a6ba56bcd84dd1f8a148470526a48515a3f4d48cd/scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1", size = 10985467 }, + { url = "https://files.pythonhosted.org/packages/a4/50/8891028437858cc510e13578fe7046574a60c2aaaa92b02d64aac5b1b412/scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5", size = 12025584 }, + { url = "https://files.pythonhosted.org/packages/d2/79/17feef8a1c14149436083bec0e61d7befb4812e272d5b20f9d79ea3e9ab1/scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908", size = 10959795 }, + { url = "https://files.pythonhosted.org/packages/b1/c8/f08313f9e2e656bd0905930ae8bf99a573ea21c34666a813b749c338202f/scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3", size = 12077302 }, + { url = "https://files.pythonhosted.org/packages/a7/48/fbfb4dc72bed0fe31fe045fb30e924909ad03f717c36694351612973b1a9/scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12", size = 13002811 }, + { url = "https://files.pythonhosted.org/packages/a5/e7/0c869f9e60d225a77af90d2aefa7a4a4c0e745b149325d1450f0f0ce5399/scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f", size = 10951354 }, +] + +[[package]] +name = "scipy" +version = "1.14.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/62/11/4d44a1f274e002784e4dbdb81e0ea96d2de2d1045b2132d5af62cc31fd28/scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417", size = 58620554 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/04/2bdacc8ac6387b15db6faa40295f8bd25eccf33f1f13e68a72dc3c60a99e/scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d", size = 39128781 }, + { url = "https://files.pythonhosted.org/packages/c8/53/35b4d41f5fd42f5781dbd0dd6c05d35ba8aa75c84ecddc7d44756cd8da2e/scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07", size = 29939542 }, + { url = "https://files.pythonhosted.org/packages/66/67/6ef192e0e4d77b20cc33a01e743b00bc9e68fb83b88e06e636d2619a8767/scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5", size = 23148375 }, + { url = "https://files.pythonhosted.org/packages/f6/32/3a6dedd51d68eb7b8e7dc7947d5d841bcb699f1bf4463639554986f4d782/scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc", size = 25578573 }, + { url = "https://files.pythonhosted.org/packages/f0/5a/efa92a58dc3a2898705f1dc9dbaf390ca7d4fba26d6ab8cfffb0c72f656f/scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310", size = 35319299 }, + { url = 
"https://files.pythonhosted.org/packages/8e/ee/8a26858ca517e9c64f84b4c7734b89bda8e63bec85c3d2f432d225bb1886/scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066", size = 40849331 }, + { url = "https://files.pythonhosted.org/packages/a5/cd/06f72bc9187840f1c99e1a8750aad4216fc7dfdd7df46e6280add14b4822/scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1", size = 42544049 }, + { url = "https://files.pythonhosted.org/packages/aa/7d/43ab67228ef98c6b5dd42ab386eae2d7877036970a0d7e3dd3eb47a0d530/scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f", size = 44521212 }, + { url = "https://files.pythonhosted.org/packages/50/ef/ac98346db016ff18a6ad7626a35808f37074d25796fd0234c2bb0ed1e054/scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79", size = 39091068 }, + { url = "https://files.pythonhosted.org/packages/b9/cc/70948fe9f393b911b4251e96b55bbdeaa8cca41f37c26fd1df0232933b9e/scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e", size = 29875417 }, + { url = "https://files.pythonhosted.org/packages/3b/2e/35f549b7d231c1c9f9639f9ef49b815d816bf54dd050da5da1c11517a218/scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73", size = 23084508 }, + { url = "https://files.pythonhosted.org/packages/3f/d6/b028e3f3e59fae61fb8c0f450db732c43dd1d836223a589a8be9f6377203/scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e", size = 25503364 }, + { url = "https://files.pythonhosted.org/packages/a7/2f/6c142b352ac15967744d62b165537a965e95d557085db4beab2a11f7943b/scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d", size = 35292639 }, + { url = "https://files.pythonhosted.org/packages/56/46/2449e6e51e0d7c3575f289f6acb7f828938eaab8874dbccfeb0cd2b71a27/scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e", size = 40798288 }, + { url = "https://files.pythonhosted.org/packages/32/cd/9d86f7ed7f4497c9fd3e39f8918dd93d9f647ba80d7e34e4946c0c2d1a7c/scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06", size = 42524647 }, + { url = "https://files.pythonhosted.org/packages/f5/1b/6ee032251bf4cdb0cc50059374e86a9f076308c1512b61c4e003e241efb7/scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84", size = 44469524 }, +] + +[[package]] +name = "sercol" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pandas" }, + { name = "serializable" }, + { name = "simplejson" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fa/ed/416326b528f9db32d9d5b9229ef9f7572b1a57b36650d782f08472d3efce/sercol-1.0.0.tar.gz", hash = "sha256:6507f5efcee4596eb77c45f1580ff562533ef5d115918a6a28bc4eac09ec191c", size = 8726 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/f3/60/87f95cede57c2a2c58b073967660a6504b3125fb45013e1999f11118f574/sercol-1.0.0-py3-none-any.whl", hash = "sha256:15ffcc4927d3bf5a3d333f68ad3c40fe7fdfb42f8257e15185b4ac160087d8eb", size = 9015 }, +] + +[[package]] +name = "serializable" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "simplejson" }, + { name = "typechecks" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/17/29/98d529dbd22b079ad29a80531cef271894ed2aaf4f3310b8c3b102c95e4e/serializable-0.4.1.tar.gz", hash = "sha256:882098d79253d38591a3d4d7d5d78052fce3ba4d29d64c1704d73f1a19d066d8", size = 12873 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bd/a6/736a99c39baa85cf51922526525300fb732f2d31138c28ae1634370bf8bf/serializable-0.4.1-py3-none-any.whl", hash = "sha256:d1f09fade42cc49f9cdf138cd8d2d3515beb1605507ae26983ea812bae77b3e2", size = 12947 }, +] + +[[package]] +name = "setuptools" +version = "75.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/07/37/b31be7e4b9f13b59cde9dcaeff112d401d49e0dc5b37ed4a9fc8fb12f409/setuptools-75.2.0.tar.gz", hash = "sha256:753bb6ebf1f465a1912e19ed1d41f403a79173a9acf66a42e7e6aec45c3c16ec", size = 1350308 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/2d/90165d51ecd38f9a02c6832198c13a4e48652485e2ccf863ebb942c531b6/setuptools-75.2.0-py3-none-any.whl", hash = "sha256:a7fcb66f68b4d9e8e66b42f9876150a3371558f98fa32222ffaa5bced76406f8", size = 1249825 }, +] + +[[package]] +name = "shellingham" +version = "1.5.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755 }, +] + +[[package]] +name = "simplejson" +version = "3.19.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3d/29/085111f19717f865eceaf0d4397bf3e76b08d60428b076b64e2a1903706d/simplejson-3.19.3.tar.gz", hash = "sha256:8e086896c36210ab6050f2f9f095a5f1e03c83fa0e7f296d6cba425411364680", size = 85237 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/15/513fea93fafbdd4993eacfcb762965b2ff3d29e618c029e2956174d68c4b/simplejson-3.19.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:66a0399e21c2112acacfebf3d832ebe2884f823b1c7e6d1363f2944f1db31a99", size = 92921 }, + { url = "https://files.pythonhosted.org/packages/a4/4f/998a907ae1a6c104dc0ee48aa248c2478490152808d34d8e07af57f396c3/simplejson-3.19.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6ef9383c5e05f445be60f1735c1816163c874c0b1ede8bb4390aff2ced34f333", size = 75311 }, + { url = "https://files.pythonhosted.org/packages/db/44/acd6122201e927451869d45952b9ab1d3025cdb5e61548d286d08fbccc08/simplejson-3.19.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:42e5acf80d4d971238d4df97811286a044d720693092b20a56d5e56b7dcc5d09", size = 74964 }, + { url = 
"https://files.pythonhosted.org/packages/27/ca/d0a1e8f16e1bbdc0b8c6d88166f45f565ed7285f53928cfef3b6ce78f14d/simplejson-3.19.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0b0efc7279d768db7c74d3d07f0b5c81280d16ae3fb14e9081dc903e8360771", size = 150106 }, + { url = "https://files.pythonhosted.org/packages/63/59/0554b78cf26c98e2b9cae3f44723bd72c2394e2afec1a14eedc6211f7187/simplejson-3.19.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0552eb06e7234da892e1d02365cd2b7b2b1f8233aa5aabdb2981587b7cc92ea0", size = 158347 }, + { url = "https://files.pythonhosted.org/packages/b2/fe/9f30890352e431e8508cc569912d3322147d3e7e4f321e48c0adfcb4c97d/simplejson-3.19.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5bf6a3b9a7d7191471b464fe38f684df10eb491ec9ea454003edb45a011ab187", size = 148456 }, + { url = "https://files.pythonhosted.org/packages/37/e3/663a09542ee021d4131162f7a164cb2e7f04ef48433a67591738afbf12ea/simplejson-3.19.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7017329ca8d4dca94ad5e59f496e5fc77630aecfc39df381ffc1d37fb6b25832", size = 152190 }, + { url = "https://files.pythonhosted.org/packages/31/20/4e0c4d35e10ff6465003bec304316d822a559a1c38c66ef6892ca199c207/simplejson-3.19.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:67a20641afebf4cfbcff50061f07daad1eace6e7b31d7622b6fa2c40d43900ba", size = 149846 }, + { url = "https://files.pythonhosted.org/packages/08/7a/46e2e072cac3987cbb05946f25167f0ad2fe536748e7405953fd6661a486/simplejson-3.19.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:dd6a7dabcc4c32daf601bc45e01b79175dde4b52548becea4f9545b0a4428169", size = 151714 }, + { url = "https://files.pythonhosted.org/packages/7f/7d/dbeeac10eb61d5d8858d0bb51121a21050d281dc83af4c557f86da28746c/simplejson-3.19.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:08f9b443a94e72dd02c87098c96886d35790e79e46b24e67accafbf13b73d43b", size = 158777 }, + { url = "https://files.pythonhosted.org/packages/fc/8f/a98bdbb799c6a4a884b5823db31785a96ba895b4b0f4d8ac345d6fe98bbf/simplejson-3.19.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fa97278ae6614346b5ca41a45a911f37a3261b57dbe4a00602048652c862c28b", size = 154230 }, + { url = "https://files.pythonhosted.org/packages/b1/db/852eebceb85f969ae40e06babed1a93d3bacb536f187d7a80ff5823a5979/simplejson-3.19.3-cp312-cp312-win32.whl", hash = "sha256:ef28c3b328d29b5e2756903aed888960bc5df39b4c2eab157ae212f70ed5bf74", size = 74002 }, + { url = "https://files.pythonhosted.org/packages/fe/68/9f0e5df0651cb79ef83cba1378765a00ee8038e6201cc82b8e7178a7778e/simplejson-3.19.3-cp312-cp312-win_amd64.whl", hash = "sha256:1e662336db50ad665777e6548b5076329a94a0c3d4a0472971c588b3ef27de3a", size = 75596 }, + { url = "https://files.pythonhosted.org/packages/93/3a/5896821ed543899fcb9c4256c7e71bb110048047349a00f42bc8b8fb379f/simplejson-3.19.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:0959e6cb62e3994b5a40e31047ff97ef5c4138875fae31659bead691bed55896", size = 92931 }, + { url = "https://files.pythonhosted.org/packages/39/15/5d33d269440912ee40d856db0c8be2b91aba7a219690ab01f86cb0edd590/simplejson-3.19.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7a7bfad839c624e139a4863007233a3f194e7c51551081f9789cba52e4da5167", size = 75318 }, + { url = 
"https://files.pythonhosted.org/packages/2a/8d/2e7483a2bf7ec53acf7e012bafbda79d7b34f90471dda8e424544a59d484/simplejson-3.19.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:afab2f7f2486a866ff04d6d905e9386ca6a231379181a3838abce1f32fbdcc37", size = 74971 }, + { url = "https://files.pythonhosted.org/packages/4d/9d/9bdf34437c8834a7cf7246f85e9d5122e30579f512c10a0c2560e994294f/simplejson-3.19.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d00313681015ac498e1736b304446ee6d1c72c5b287cd196996dad84369998f7", size = 150112 }, + { url = "https://files.pythonhosted.org/packages/a7/e2/1f2ae2d89eaf85f6163c82150180aae5eaa18085cfaf892f8a57d4c51cbd/simplejson-3.19.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d936ae682d5b878af9d9eb4d8bb1fdd5e41275c8eb59ceddb0aeed857bb264a2", size = 158354 }, + { url = "https://files.pythonhosted.org/packages/60/83/26f610adf234c8492b3f30501e12f2271e67790f946c6898fe0c58aefe99/simplejson-3.19.3-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01c6657485393f2e9b8177c77a7634f13ebe70d5e6de150aae1677d91516ce6b", size = 148455 }, + { url = "https://files.pythonhosted.org/packages/b5/4b/109af50006af77133653c55b5b91b4bd2d579ff8254ce11216c0b75f911b/simplejson-3.19.3-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a6a750d3c7461b1c47cfc6bba8d9e57a455e7c5f80057d2a82f738040dd1129", size = 152191 }, + { url = "https://files.pythonhosted.org/packages/75/dc/108872a8825cbd99ae6f4334e0490ff1580367baf12198bcaf988f6820ba/simplejson-3.19.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ea7a4a998c87c5674a27089e022110a1a08a7753f21af3baf09efe9915c23c3c", size = 149954 }, + { url = "https://files.pythonhosted.org/packages/eb/be/deec1d947a5d0472276ab4a4d1a9378dc5ee27f3dc9e54d4f62ffbad7a08/simplejson-3.19.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:6300680d83a399be2b8f3b0ef7ef90b35d2a29fe6e9c21438097e0938bbc1564", size = 151812 }, + { url = "https://files.pythonhosted.org/packages/e9/58/4ee130702d36b1551ef66e7587eefe56651f3669255bf748cd71691e2434/simplejson-3.19.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:ab69f811a660c362651ae395eba8ce84f84c944cea0df5718ea0ba9d1e4e7252", size = 158880 }, + { url = "https://files.pythonhosted.org/packages/0f/e1/59cc6a371b60f89e3498d9f4c8109f6b7359094d453f5fe80b2677b777b0/simplejson-3.19.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:256e09d0f94d9c3d177d9e95fd27a68c875a4baa2046633df387b86b652f5747", size = 154344 }, + { url = "https://files.pythonhosted.org/packages/79/45/1b36044670016f5cb25ebd92497427d2d1711ecb454d00f71eb9a00b77cc/simplejson-3.19.3-cp313-cp313-win32.whl", hash = "sha256:2c78293470313aefa9cfc5e3f75ca0635721fb016fb1121c1c5b0cb8cc74712a", size = 74002 }, + { url = "https://files.pythonhosted.org/packages/e2/58/b06226e6b0612f2b1fa13d5273551da259f894566b1eef32249ddfdcce44/simplejson-3.19.3-cp313-cp313-win_amd64.whl", hash = "sha256:3bbcdc438dc1683b35f7a8dc100960c721f922f9ede8127f63bed7dfded4c64c", size = 75599 }, + { url = "https://files.pythonhosted.org/packages/0d/e7/f9fafbd4f39793a20cc52e77bbd766f7384312526d402c382928dc7667f6/simplejson-3.19.3-py3-none-any.whl", hash = "sha256:49cc4c7b940d43bd12bf87ec63f28cbc4964fc4e12c031cc8cd01650f43eb94e", size = 57004 }, +] + +[[package]] +name = "six" +version = "1.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/71/39/171f1c67cd00715f190ba0b100d606d440a28c93c7714febeca8b79af85e/six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926", size = 34041 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/5a/e7c31adbe875f2abbb91bd84cf2dc52d792b5a01506781dbcf25c91daf11/six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254", size = 11053 }, +] + +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 }, +] + +[[package]] +name = "starlette" +version = "0.40.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4b/cb/244daf0d7be4508099ad5bca3cdfe8b8b5538acd719c5f397f614e569fff/starlette-0.40.0.tar.gz", hash = "sha256:1a3139688fb298ce5e2d661d37046a66ad996ce94be4d4983be019a23a04ea35", size = 2573611 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0a/0f/64baf7a06492e8c12f5c4b49db286787a7255195df496fc21f5fd9eecffa/starlette-0.40.0-py3-none-any.whl", hash = "sha256:c494a22fae73805376ea6bf88439783ecfba9aac88a43911b48c653437e784c4", size = 73303 }, +] + +[[package]] +name = "tensorboard" +version = "2.17.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "grpcio" }, + { name = "markdown" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "protobuf" }, + { name = "setuptools" }, + { name = "six" }, + { name = "tensorboard-data-server" }, + { name = "werkzeug" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/41/dccba8c5f955bc35b6110ff78574e4e5c8226ad62f08e732096c3861309b/tensorboard-2.17.1-py3-none-any.whl", hash = "sha256:253701a224000eeca01eee6f7e978aea7b408f60b91eb0babdb04e78947b773e", size = 5502989 }, +] + +[[package]] +name = "tensorboard-data-server" +version = "0.7.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/13/e503968fefabd4c6b2650af21e110aa8466fe21432cd7c43a84577a89438/tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb", size = 2356 }, + { url = "https://files.pythonhosted.org/packages/b7/85/dabeaf902892922777492e1d253bb7e1264cadce3cea932f7ff599e53fea/tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60", size = 4823598 }, + { url = "https://files.pythonhosted.org/packages/73/c6/825dab04195756cf8ff2e12698f22513b3db2f64925bdd41671bfb33aaa5/tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530", size = 6590363 }, +] + +[[package]] +name = "tensorflow" +version = "2.17.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = 
"astunparse" }, + { name = "flatbuffers" }, + { name = "gast" }, + { name = "google-pasta" }, + { name = "grpcio" }, + { name = "h5py" }, + { name = "keras" }, + { name = "libclang" }, + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "opt-einsum" }, + { name = "packaging" }, + { name = "protobuf" }, + { name = "requests" }, + { name = "setuptools" }, + { name = "six" }, + { name = "tensorboard" }, + { name = "termcolor" }, + { name = "typing-extensions" }, + { name = "wrapt" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/cb/aab1d277dc1cb45166a1ac857bf535ef322437f4f6fbef55ad6dee31e00d/tensorflow-2.17.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ee18b4fcd627c5e872eabb25092af6c808b6ec77948662c88fc5c89a60eb0211", size = 236281093 }, + { url = "https://files.pythonhosted.org/packages/e5/db/a365f4bdb6a0067ea9cc0a4b27757e2755627f3ff29dcca4cce3d274662d/tensorflow-2.17.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:72adfef0ee39dd641627906fd7b244fcf21bdd8a87216a998ed74d9c74653aff", size = 224039182 }, + { url = "https://files.pythonhosted.org/packages/d7/24/f7b5c130975303efa4a91341e294fd784b94910c2a4c0f0f0561e5fc7405/tensorflow-2.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ad7bfea6afb4ded3928ca5b24df9fda876cea4904c103a5163fcc0c3483e7a4", size = 601419499 }, + { url = "https://files.pythonhosted.org/packages/48/a7/2f03f6de3c4976db6d2a898c408feb491ca10399af1f6039d6bef3e6ba6a/tensorflow-2.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:278bc80642d799adf08dc4e04f291aab603bba7457d50c1f9bc191ebbca83f43", size = 2045 }, +] + +[[package]] +name = "termcolor" +version = "2.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/37/72/88311445fd44c455c7d553e61f95412cf89054308a1aa2434ab835075fc5/termcolor-2.5.0.tar.gz", hash = "sha256:998d8d27da6d48442e8e1f016119076b690d962507531df4890fcd2db2ef8a6f", size = 13057 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7f/be/df630c387a0a054815d60be6a97eb4e8f17385d5d6fe660e1c02750062b4/termcolor-2.5.0-py3-none-any.whl", hash = "sha256:37b17b5fc1e604945c2642c872a3764b5d547a48009871aea3edd3afa180afb8", size = 7755 }, +] + +[[package]] +name = "tf-keras" +version = "2.17.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tensorflow" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dc/2b/d647100a2e80d159b020f1dbc2ef2c6787ed33c914951a63b3c88cd805d0/tf_keras-2.17.0.tar.gz", hash = "sha256:fda97c18da30da0f72a5a7e80f3eee343b09f4c206dad6c57c944fb2cd18560e", size = 1260098 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/8b/75f7572ec0273ed8da50bc19defe08aaaafcc15fda3407db53f49acec814/tf_keras-2.17.0-py3-none-any.whl", hash = "sha256:cc97717e4dc08487f327b0740a984043a9e0123c7a4e21206711669d3ec41c88", size = 1724905 }, +] + +[[package]] +name = "threadpoolctl" +version = "3.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bd/55/b5148dcbf72f5cde221f8bfe3b6a540da7aa1842f6b491ad979a6c8b84af/threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107", size = 41936 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/2c/ffbf7a134b9ab11a67b0cf0726453cedd9c5043a4fe7a35d1cefa9a1bcfb/threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467", size = 
18414 }, +] + +[[package]] +name = "tinytimer" +version = "0.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/87/75d3fbe15aabd45d9b70702241787cf1f7f30dd9fabcd9bc89d828c7661d/tinytimer-0.0.0.tar.gz", hash = "sha256:6ad13c8f01ab6094e58081a5367ffc4c5831f2d6b29034d2434d8ae106308fa5", size = 2069 } + +[[package]] +name = "tomlkit" +version = "0.13.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b1/09/a439bec5888f00a54b8b9f05fa94d7f901d6735ef4e55dcec9bc37b5d8fa/tomlkit-0.13.2.tar.gz", hash = "sha256:fff5fe59a87295b278abd31bec92c15d9bc4a06885ab12bcea52c71119392e79", size = 192885 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/b6/a447b5e4ec71e13871be01ba81f5dfc9d0af7e473da256ff46bc0e24026f/tomlkit-0.13.2-py3-none-any.whl", hash = "sha256:7a974427f6e119197f670fbbbeae7bef749a6c14e793db934baefc1b5f03efde", size = 37955 }, +] + +[[package]] +name = "tqdm" +version = "4.66.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "platform_system == 'Windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/5d/acf5905c36149bbaec41ccf7f2b68814647347b72075ac0b1fe3022fdc73/tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd", size = 78351 }, +] + +[[package]] +name = "typechecks" +version = "0.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/62/21/15129201c1f52f6af1e7809e96dce46da4b406c2c07fe9425e92f51edc5c/typechecks-0.1.0.tar.gz", hash = "sha256:7d801a6018f60d2a10aa3debc3af65f590c96c455de67159f39b9b183107c83b", size = 3397 } + +[[package]] +name = "typer" +version = "0.12.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "rich" }, + { name = "shellingham" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c5/58/a79003b91ac2c6890fc5d90145c662fd5771c6f11447f116b63300436bc9/typer-0.12.5.tar.gz", hash = "sha256:f592f089bedcc8ec1b974125d64851029c3b1af145f04aca64d69410f0c9b722", size = 98953 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/2b/886d13e742e514f704c33c4caa7df0f3b89e5a25ef8db02aa9ca3d9535d5/typer-0.12.5-py3-none-any.whl", hash = "sha256:62fe4e471711b147e3365034133904df3e235698399bc4de2b36c8579298d52b", size = 47288 }, +] + +[[package]] +name = "typing-extensions" +version = "4.12.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/df/db/f35a00659bc03fec321ba8bce9420de607a1d37f8342eee1863174c69557/typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8", size = 85321 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d", size = 37438 }, +] + +[[package]] +name = "tzdata" +version = "2024.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/e1/34/943888654477a574a86a98e9896bae89c7aa15078ec29f490fef2f1e5384/tzdata-2024.2.tar.gz", hash = "sha256:7d85cc416e9382e69095b7bdf4afd9e3880418a2413feec7069d533d6b4e31cc", size = 193282 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/ab/7e5f53c3b9d14972843a647d8d7a853969a58aecc7559cb3267302c94774/tzdata-2024.2-py2.py3-none-any.whl", hash = "sha256:a48093786cdcde33cad18c2555e8532f34422074448fbc874186f0abd79565cd", size = 346586 }, +] + +[[package]] +name = "urllib3" +version = "2.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ed/63/22ba4ebfe7430b76388e7cd448d5478814d3032121827c12a2cc287e2260/urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9", size = 300677 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ce/d9/5f4c13cecde62396b0d3fe530a50ccea91e7dfc1ccf0e09c228841bb5ba8/urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac", size = 126338 }, +] + +[[package]] +name = "uvicorn" +version = "0.32.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e0/fc/1d785078eefd6945f3e5bab5c076e4230698046231eb0f3747bc5c8fa992/uvicorn-0.32.0.tar.gz", hash = "sha256:f78b36b143c16f54ccdb8190d0a26b5f1901fe5a3c777e1ab29f26391af8551e", size = 77564 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/14/78bd0e95dd2444b6caacbca2b730671d4295ccb628ef58b81bee903629df/uvicorn-0.32.0-py3-none-any.whl", hash = "sha256:60b8f3a5ac027dcd31448f411ced12b5ef452c646f76f02f8cc3f25d8d26fd82", size = 63723 }, +] + +[package.optional-dependencies] +standard = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "httptools" }, + { name = "python-dotenv" }, + { name = "pyyaml" }, + { name = "uvloop", marker = "platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'" }, + { name = "watchfiles" }, + { name = "websockets" }, +] + +[[package]] +name = "uvloop" +version = "0.21.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/af/c0/854216d09d33c543f12a44b393c402e89a920b1a0a7dc634c42de91b9cf6/uvloop-0.21.0.tar.gz", hash = "sha256:3bf12b0fda68447806a7ad847bfa591613177275d35b6724b1ee573faa3704e3", size = 2492741 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8c/4c/03f93178830dc7ce8b4cdee1d36770d2f5ebb6f3d37d354e061eefc73545/uvloop-0.21.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:359ec2c888397b9e592a889c4d72ba3d6befba8b2bb01743f72fffbde663b59c", size = 1471284 }, + { url = "https://files.pythonhosted.org/packages/43/3e/92c03f4d05e50f09251bd8b2b2b584a2a7f8fe600008bcc4523337abe676/uvloop-0.21.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f7089d2dc73179ce5ac255bdf37c236a9f914b264825fdaacaded6990a7fb4c2", size = 821349 }, + { url = "https://files.pythonhosted.org/packages/a6/ef/a02ec5da49909dbbfb1fd205a9a1ac4e88ea92dcae885e7c961847cd51e2/uvloop-0.21.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:baa4dcdbd9ae0a372f2167a207cd98c9f9a1ea1188a8a526431eef2f8116cc8d", size = 4580089 }, + { url = 
"https://files.pythonhosted.org/packages/06/a7/b4e6a19925c900be9f98bec0a75e6e8f79bb53bdeb891916609ab3958967/uvloop-0.21.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86975dca1c773a2c9864f4c52c5a55631038e387b47eaf56210f873887b6c8dc", size = 4693770 }, + { url = "https://files.pythonhosted.org/packages/ce/0c/f07435a18a4b94ce6bd0677d8319cd3de61f3a9eeb1e5f8ab4e8b5edfcb3/uvloop-0.21.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:461d9ae6660fbbafedd07559c6a2e57cd553b34b0065b6550685f6653a98c1cb", size = 4451321 }, + { url = "https://files.pythonhosted.org/packages/8f/eb/f7032be105877bcf924709c97b1bf3b90255b4ec251f9340cef912559f28/uvloop-0.21.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:183aef7c8730e54c9a3ee3227464daed66e37ba13040bb3f350bc2ddc040f22f", size = 4659022 }, + { url = "https://files.pythonhosted.org/packages/3f/8d/2cbef610ca21539f0f36e2b34da49302029e7c9f09acef0b1c3b5839412b/uvloop-0.21.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:bfd55dfcc2a512316e65f16e503e9e450cab148ef11df4e4e679b5e8253a5281", size = 1468123 }, + { url = "https://files.pythonhosted.org/packages/93/0d/b0038d5a469f94ed8f2b2fce2434a18396d8fbfb5da85a0a9781ebbdec14/uvloop-0.21.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:787ae31ad8a2856fc4e7c095341cccc7209bd657d0e71ad0dc2ea83c4a6fa8af", size = 819325 }, + { url = "https://files.pythonhosted.org/packages/50/94/0a687f39e78c4c1e02e3272c6b2ccdb4e0085fda3b8352fecd0410ccf915/uvloop-0.21.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ee4d4ef48036ff6e5cfffb09dd192c7a5027153948d85b8da7ff705065bacc6", size = 4582806 }, + { url = "https://files.pythonhosted.org/packages/d2/19/f5b78616566ea68edd42aacaf645adbf71fbd83fc52281fba555dc27e3f1/uvloop-0.21.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3df876acd7ec037a3d005b3ab85a7e4110422e4d9c1571d4fc89b0fc41b6816", size = 4701068 }, + { url = "https://files.pythonhosted.org/packages/47/57/66f061ee118f413cd22a656de622925097170b9380b30091b78ea0c6ea75/uvloop-0.21.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bd53ecc9a0f3d87ab847503c2e1552b690362e005ab54e8a48ba97da3924c0dc", size = 4454428 }, + { url = "https://files.pythonhosted.org/packages/63/9a/0962b05b308494e3202d3f794a6e85abe471fe3cafdbcf95c2e8c713aabd/uvloop-0.21.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a5c39f217ab3c663dc699c04cbd50c13813e31d917642d459fdcec07555cc553", size = 4660018 }, +] + +[[package]] +name = "varcode" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "biopython" }, + { name = "memoized-property" }, + { name = "numpy" }, + { name = "pandas" }, + { name = "pyensembl" }, + { name = "pyvcf3" }, + { name = "sercol" }, + { name = "serializable" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/24/17/45c79abe87582824af2337a060be702507b00e9cd962c6c08a203669aeff/varcode-1.2.1.tar.gz", hash = "sha256:f2f0f608b266304cb6ceaa353357ea089cc6ae3f1fa15c8824b44e61fcf567cb", size = 85655 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/c6/b38509324c7fc26d5b617e4617a936abcb410d2cb8c1f6cde0dd0169d2ac/varcode-1.2.1-py3-none-any.whl", hash = "sha256:7f35b0ebb8c3752c014f7613e15bea38112f42a14652c4455a82b00716bfdc08", size = 120418 }, +] + +[[package]] +name = "watchfiles" +version = "0.24.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/c8/27/2ba23c8cc85796e2d41976439b08d52f691655fdb9401362099502d1f0cf/watchfiles-0.24.0.tar.gz", hash = "sha256:afb72325b74fa7a428c009c1b8be4b4d7c2afedafb2982827ef2156646df2fe1", size = 37870 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/35/82/92a7bb6dc82d183e304a5f84ae5437b59ee72d48cee805a9adda2488b237/watchfiles-0.24.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:7211b463695d1e995ca3feb38b69227e46dbd03947172585ecb0588f19b0d87a", size = 374137 }, + { url = "https://files.pythonhosted.org/packages/87/91/49e9a497ddaf4da5e3802d51ed67ff33024597c28f652b8ab1e7c0f5718b/watchfiles-0.24.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4b8693502d1967b00f2fb82fc1e744df128ba22f530e15b763c8d82baee15370", size = 367733 }, + { url = "https://files.pythonhosted.org/packages/0d/d8/90eb950ab4998effea2df4cf3a705dc594f6bc501c5a353073aa990be965/watchfiles-0.24.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cdab9555053399318b953a1fe1f586e945bc8d635ce9d05e617fd9fe3a4687d6", size = 437322 }, + { url = "https://files.pythonhosted.org/packages/6c/a2/300b22e7bc2a222dd91fce121cefa7b49aa0d26a627b2777e7bdfcf1110b/watchfiles-0.24.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:34e19e56d68b0dad5cff62273107cf5d9fbaf9d75c46277aa5d803b3ef8a9e9b", size = 433409 }, + { url = "https://files.pythonhosted.org/packages/99/44/27d7708a43538ed6c26708bcccdde757da8b7efb93f4871d4cc39cffa1cc/watchfiles-0.24.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:41face41f036fee09eba33a5b53a73e9a43d5cb2c53dad8e61fa6c9f91b5a51e", size = 452142 }, + { url = "https://files.pythonhosted.org/packages/b0/ec/c4e04f755be003129a2c5f3520d2c47026f00da5ecb9ef1e4f9449637571/watchfiles-0.24.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5148c2f1ea043db13ce9b0c28456e18ecc8f14f41325aa624314095b6aa2e9ea", size = 469414 }, + { url = "https://files.pythonhosted.org/packages/c5/4e/cdd7de3e7ac6432b0abf282ec4c1a1a2ec62dfe423cf269b86861667752d/watchfiles-0.24.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e4bd963a935aaf40b625c2499f3f4f6bbd0c3776f6d3bc7c853d04824ff1c9f", size = 472962 }, + { url = "https://files.pythonhosted.org/packages/27/69/e1da9d34da7fc59db358424f5d89a56aaafe09f6961b64e36457a80a7194/watchfiles-0.24.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c79d7719d027b7a42817c5d96461a99b6a49979c143839fc37aa5748c322f234", size = 425705 }, + { url = "https://files.pythonhosted.org/packages/e8/c1/24d0f7357be89be4a43e0a656259676ea3d7a074901f47022f32e2957798/watchfiles-0.24.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:32aa53a9a63b7f01ed32e316e354e81e9da0e6267435c7243bf8ae0f10b428ef", size = 612851 }, + { url = "https://files.pythonhosted.org/packages/c7/af/175ba9b268dec56f821639c9893b506c69fd999fe6a2e2c51de420eb2f01/watchfiles-0.24.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ce72dba6a20e39a0c628258b5c308779b8697f7676c254a845715e2a1039b968", size = 594868 }, + { url = "https://files.pythonhosted.org/packages/44/81/1f701323a9f70805bc81c74c990137123344a80ea23ab9504a99492907f8/watchfiles-0.24.0-cp312-none-win32.whl", hash = "sha256:d9018153cf57fc302a2a34cb7564870b859ed9a732d16b41a9b5cb2ebed2d444", size = 264109 }, + { url = "https://files.pythonhosted.org/packages/b4/0b/32cde5bc2ebd9f351be326837c61bdeb05ad652b793f25c91cac0b48a60b/watchfiles-0.24.0-cp312-none-win_amd64.whl", hash = 
"sha256:551ec3ee2a3ac9cbcf48a4ec76e42c2ef938a7e905a35b42a1267fa4b1645896", size = 277055 }, + { url = "https://files.pythonhosted.org/packages/4b/81/daade76ce33d21dbec7a15afd7479de8db786e5f7b7d249263b4ea174e08/watchfiles-0.24.0-cp312-none-win_arm64.whl", hash = "sha256:b52a65e4ea43c6d149c5f8ddb0bef8d4a1e779b77591a458a893eb416624a418", size = 266169 }, + { url = "https://files.pythonhosted.org/packages/30/dc/6e9f5447ae14f645532468a84323a942996d74d5e817837a5c8ce9d16c69/watchfiles-0.24.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:3d2e3ab79a1771c530233cadfd277fcc762656d50836c77abb2e5e72b88e3a48", size = 373764 }, + { url = "https://files.pythonhosted.org/packages/79/c0/c3a9929c372816c7fc87d8149bd722608ea58dc0986d3ef7564c79ad7112/watchfiles-0.24.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:327763da824817b38ad125dcd97595f942d720d32d879f6c4ddf843e3da3fe90", size = 367873 }, + { url = "https://files.pythonhosted.org/packages/2e/11/ff9a4445a7cfc1c98caf99042df38964af12eed47d496dd5d0d90417349f/watchfiles-0.24.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd82010f8ab451dabe36054a1622870166a67cf3fce894f68895db6f74bbdc94", size = 438381 }, + { url = "https://files.pythonhosted.org/packages/48/a3/763ba18c98211d7bb6c0f417b2d7946d346cdc359d585cc28a17b48e964b/watchfiles-0.24.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d64ba08db72e5dfd5c33be1e1e687d5e4fcce09219e8aee893a4862034081d4e", size = 432809 }, + { url = "https://files.pythonhosted.org/packages/30/4c/616c111b9d40eea2547489abaf4ffc84511e86888a166d3a4522c2ba44b5/watchfiles-0.24.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1cf1f6dd7825053f3d98f6d33f6464ebdd9ee95acd74ba2c34e183086900a827", size = 451801 }, + { url = "https://files.pythonhosted.org/packages/b6/be/d7da83307863a422abbfeb12903a76e43200c90ebe5d6afd6a59d158edea/watchfiles-0.24.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43e3e37c15a8b6fe00c1bce2473cfa8eb3484bbeecf3aefbf259227e487a03df", size = 468886 }, + { url = "https://files.pythonhosted.org/packages/1d/d3/3dfe131ee59d5e90b932cf56aba5c996309d94dafe3d02d204364c23461c/watchfiles-0.24.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:88bcd4d0fe1d8ff43675360a72def210ebad3f3f72cabfeac08d825d2639b4ab", size = 472973 }, + { url = "https://files.pythonhosted.org/packages/42/6c/279288cc5653a289290d183b60a6d80e05f439d5bfdfaf2d113738d0f932/watchfiles-0.24.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:999928c6434372fde16c8f27143d3e97201160b48a614071261701615a2a156f", size = 425282 }, + { url = "https://files.pythonhosted.org/packages/d6/d7/58afe5e85217e845edf26d8780c2d2d2ae77675eeb8d1b8b8121d799ce52/watchfiles-0.24.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:30bbd525c3262fd9f4b1865cb8d88e21161366561cd7c9e1194819e0a33ea86b", size = 612540 }, + { url = "https://files.pythonhosted.org/packages/6d/d5/b96eeb9fe3fda137200dd2f31553670cbc731b1e13164fd69b49870b76ec/watchfiles-0.24.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:edf71b01dec9f766fb285b73930f95f730bb0943500ba0566ae234b5c1618c18", size = 593625 }, + { url = "https://files.pythonhosted.org/packages/c1/e5/c326fe52ee0054107267608d8cea275e80be4455b6079491dfd9da29f46f/watchfiles-0.24.0-cp313-none-win32.whl", hash = "sha256:f4c96283fca3ee09fb044f02156d9570d156698bc3734252175a38f0e8975f07", size = 263899 }, + { url = 
"https://files.pythonhosted.org/packages/a6/8b/8a7755c5e7221bb35fe4af2dc44db9174f90ebf0344fd5e9b1e8b42d381e/watchfiles-0.24.0-cp313-none-win_amd64.whl", hash = "sha256:a974231b4fdd1bb7f62064a0565a6b107d27d21d9acb50c484d2cdba515b9366", size = 276622 }, +] + +[[package]] +name = "websockets" +version = "13.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e2/73/9223dbc7be3dcaf2a7bbf756c351ec8da04b1fa573edaf545b95f6b0c7fd/websockets-13.1.tar.gz", hash = "sha256:a3b3366087c1bc0a2795111edcadddb8b3b59509d5db5d7ea3fdd69f954a8878", size = 158549 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/df/46/c426282f543b3c0296cf964aa5a7bb17e984f58dde23460c3d39b3148fcf/websockets-13.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9d75baf00138f80b48f1eac72ad1535aac0b6461265a0bcad391fc5aba875cfc", size = 157821 }, + { url = "https://files.pythonhosted.org/packages/aa/85/22529867010baac258da7c45848f9415e6cf37fef00a43856627806ffd04/websockets-13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9b6f347deb3dcfbfde1c20baa21c2ac0751afaa73e64e5b693bb2b848efeaa49", size = 155480 }, + { url = "https://files.pythonhosted.org/packages/29/2c/bdb339bfbde0119a6e84af43ebf6275278698a2241c2719afc0d8b0bdbf2/websockets-13.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:de58647e3f9c42f13f90ac7e5f58900c80a39019848c5547bc691693098ae1bd", size = 155715 }, + { url = "https://files.pythonhosted.org/packages/9f/d0/8612029ea04c5c22bf7af2fd3d63876c4eaeef9b97e86c11972a43aa0e6c/websockets-13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1b54689e38d1279a51d11e3467dd2f3a50f5f2e879012ce8f2d6943f00e83f0", size = 165647 }, + { url = "https://files.pythonhosted.org/packages/56/04/1681ed516fa19ca9083f26d3f3a302257e0911ba75009533ed60fbb7b8d1/websockets-13.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf1781ef73c073e6b0f90af841aaf98501f975d306bbf6221683dd594ccc52b6", size = 164592 }, + { url = "https://files.pythonhosted.org/packages/38/6f/a96417a49c0ed132bb6087e8e39a37db851c70974f5c724a4b2a70066996/websockets-13.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d23b88b9388ed85c6faf0e74d8dec4f4d3baf3ecf20a65a47b836d56260d4b9", size = 165012 }, + { url = "https://files.pythonhosted.org/packages/40/8b/fccf294919a1b37d190e86042e1a907b8f66cff2b61e9befdbce03783e25/websockets-13.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3c78383585f47ccb0fcf186dcb8a43f5438bd7d8f47d69e0b56f71bf431a0a68", size = 165311 }, + { url = "https://files.pythonhosted.org/packages/c1/61/f8615cf7ce5fe538476ab6b4defff52beb7262ff8a73d5ef386322d9761d/websockets-13.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:d6d300f8ec35c24025ceb9b9019ae9040c1ab2f01cddc2bcc0b518af31c75c14", size = 164692 }, + { url = "https://files.pythonhosted.org/packages/5c/f1/a29dd6046d3a722d26f182b783a7997d25298873a14028c4760347974ea3/websockets-13.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a9dcaf8b0cc72a392760bb8755922c03e17a5a54e08cca58e8b74f6902b433cf", size = 164686 }, + { url = "https://files.pythonhosted.org/packages/0f/99/ab1cdb282f7e595391226f03f9b498f52109d25a2ba03832e21614967dfa/websockets-13.1-cp312-cp312-win32.whl", hash = "sha256:2f85cf4f2a1ba8f602298a853cec8526c2ca42a9a4b947ec236eaedb8f2dc80c", size = 158712 }, + { url = 
"https://files.pythonhosted.org/packages/46/93/e19160db48b5581feac8468330aa11b7292880a94a37d7030478596cc14e/websockets-13.1-cp312-cp312-win_amd64.whl", hash = "sha256:38377f8b0cdeee97c552d20cf1865695fcd56aba155ad1b4ca8779a5b6ef4ac3", size = 159145 }, + { url = "https://files.pythonhosted.org/packages/51/20/2b99ca918e1cbd33c53db2cace5f0c0cd8296fc77558e1908799c712e1cd/websockets-13.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a9ab1e71d3d2e54a0aa646ab6d4eebfaa5f416fe78dfe4da2839525dc5d765c6", size = 157828 }, + { url = "https://files.pythonhosted.org/packages/b8/47/0932a71d3d9c0e9483174f60713c84cee58d62839a143f21a2bcdbd2d205/websockets-13.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b9d7439d7fab4dce00570bb906875734df13d9faa4b48e261c440a5fec6d9708", size = 155487 }, + { url = "https://files.pythonhosted.org/packages/a9/60/f1711eb59ac7a6c5e98e5637fef5302f45b6f76a2c9d64fd83bbb341377a/websockets-13.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:327b74e915cf13c5931334c61e1a41040e365d380f812513a255aa804b183418", size = 155721 }, + { url = "https://files.pythonhosted.org/packages/6a/e6/ba9a8db7f9d9b0e5f829cf626ff32677f39824968317223605a6b419d445/websockets-13.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:325b1ccdbf5e5725fdcb1b0e9ad4d2545056479d0eee392c291c1bf76206435a", size = 165609 }, + { url = "https://files.pythonhosted.org/packages/c1/22/4ec80f1b9c27a0aebd84ccd857252eda8418ab9681eb571b37ca4c5e1305/websockets-13.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:346bee67a65f189e0e33f520f253d5147ab76ae42493804319b5716e46dddf0f", size = 164556 }, + { url = "https://files.pythonhosted.org/packages/27/ac/35f423cb6bb15600438db80755609d27eda36d4c0b3c9d745ea12766c45e/websockets-13.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91a0fa841646320ec0d3accdff5b757b06e2e5c86ba32af2e0815c96c7a603c5", size = 164993 }, + { url = "https://files.pythonhosted.org/packages/31/4e/98db4fd267f8be9e52e86b6ee4e9aa7c42b83452ea0ea0672f176224b977/websockets-13.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:18503d2c5f3943e93819238bf20df71982d193f73dcecd26c94514f417f6b135", size = 165360 }, + { url = "https://files.pythonhosted.org/packages/3f/15/3f0de7cda70ffc94b7e7024544072bc5b26e2c1eb36545291abb755d8cdb/websockets-13.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a9cd1af7e18e5221d2878378fbc287a14cd527fdd5939ed56a18df8a31136bb2", size = 164745 }, + { url = "https://files.pythonhosted.org/packages/a1/6e/66b6b756aebbd680b934c8bdbb6dcb9ce45aad72cde5f8a7208dbb00dd36/websockets-13.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:70c5be9f416aa72aab7a2a76c90ae0a4fe2755c1816c153c1a2bcc3333ce4ce6", size = 164732 }, + { url = "https://files.pythonhosted.org/packages/35/c6/12e3aab52c11aeb289e3dbbc05929e7a9d90d7a9173958477d3ef4f8ce2d/websockets-13.1-cp313-cp313-win32.whl", hash = "sha256:624459daabeb310d3815b276c1adef475b3e6804abaf2d9d2c061c319f7f187d", size = 158709 }, + { url = "https://files.pythonhosted.org/packages/41/d8/63d6194aae711d7263df4498200c690a9c39fb437ede10f3e157a6343e0d/websockets-13.1-cp313-cp313-win_amd64.whl", hash = "sha256:c518e84bb59c2baae725accd355c8dc517b4a3ed8db88b4bc93c78dae2974bf2", size = 159144 }, + { url = "https://files.pythonhosted.org/packages/56/27/96a5cd2626d11c8280656c6c71d8ab50fe006490ef9971ccd154e0c42cd2/websockets-13.1-py3-none-any.whl", hash = 
"sha256:a9a396a6ad26130cdae92ae10c36af09d9bfe6cafe69670fd3b6da9b07b4044f", size = 152134 }, +] + +[[package]] +name = "werkzeug" +version = "3.0.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0f/e2/6dbcaab07560909ff8f654d3a2e5a60552d937c909455211b1b36d7101dc/werkzeug-3.0.4.tar.gz", hash = "sha256:34f2371506b250df4d4f84bfe7b0921e4762525762bbd936614909fe25cd7306", size = 803966 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/84/997bbf7c2bf2dc3f09565c6d0b4959fefe5355c18c4096cfd26d83e0785b/werkzeug-3.0.4-py3-none-any.whl", hash = "sha256:02c9eb92b7d6c06f31a782811505d2157837cea66aaede3e217c7c27c039476c", size = 227554 }, +] + +[[package]] +name = "wheel" +version = "0.44.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/a0/95e9e962c5fd9da11c1e28aa4c0d8210ab277b1ada951d2aee336b505813/wheel-0.44.0.tar.gz", hash = "sha256:a29c3f2817e95ab89aa4660681ad547c0e9547f20e75b0562fe7723c9a2a9d49", size = 100733 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1b/d1/9babe2ccaecff775992753d8686970b1e2755d21c8a63be73aba7a4e7d77/wheel-0.44.0-py3-none-any.whl", hash = "sha256:2376a90c98cc337d18623527a97c31797bd02bad0033d41547043a1cbfbe448f", size = 67059 }, +] + +[[package]] +name = "wrapt" +version = "1.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/4c/063a912e20bcef7124e0df97282a8af3ff3e4b603ce84c481d6d7346be0a/wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d", size = 53972 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/92/17/224132494c1e23521868cdd57cd1e903f3b6a7ba6996b7b8f077ff8ac7fe/wrapt-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b", size = 37614 }, + { url = "https://files.pythonhosted.org/packages/6a/d7/cfcd73e8f4858079ac59d9db1ec5a1349bc486ae8e9ba55698cc1f4a1dff/wrapt-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36", size = 38316 }, + { url = "https://files.pythonhosted.org/packages/7e/79/5ff0a5c54bda5aec75b36453d06be4f83d5cd4932cc84b7cb2b52cee23e2/wrapt-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73", size = 86322 }, + { url = "https://files.pythonhosted.org/packages/c4/81/e799bf5d419f422d8712108837c1d9bf6ebe3cb2a81ad94413449543a923/wrapt-1.16.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809", size = 79055 }, + { url = "https://files.pythonhosted.org/packages/62/62/30ca2405de6a20448ee557ab2cd61ab9c5900be7cbd18a2639db595f0b98/wrapt-1.16.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b", size = 87291 }, + { url = "https://files.pythonhosted.org/packages/49/4e/5d2f6d7b57fc9956bf06e944eb00463551f7d52fc73ca35cfc4c2cdb7aed/wrapt-1.16.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81", size = 90374 }, + { url = 
"https://files.pythonhosted.org/packages/a6/9b/c2c21b44ff5b9bf14a83252a8b973fb84923764ff63db3e6dfc3895cf2e0/wrapt-1.16.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9", size = 83896 }, + { url = "https://files.pythonhosted.org/packages/14/26/93a9fa02c6f257df54d7570dfe8011995138118d11939a4ecd82cb849613/wrapt-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c", size = 91738 }, + { url = "https://files.pythonhosted.org/packages/a2/5b/4660897233eb2c8c4de3dc7cefed114c61bacb3c28327e64150dc44ee2f6/wrapt-1.16.0-cp312-cp312-win32.whl", hash = "sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc", size = 35568 }, + { url = "https://files.pythonhosted.org/packages/5c/cc/8297f9658506b224aa4bd71906447dea6bb0ba629861a758c28f67428b91/wrapt-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8", size = 37653 }, + { url = "https://files.pythonhosted.org/packages/ff/21/abdedb4cdf6ff41ebf01a74087740a709e2edb146490e4d9beea054b0b7a/wrapt-1.16.0-py3-none-any.whl", hash = "sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1", size = 23362 }, +] From ad48bff9b20cc9a0708029cd0929bb284fddccc2 Mon Sep 17 00:00:00 2001 From: ElektrikSpark Date: Sat, 19 Oct 2024 21:02:55 +0000 Subject: [PATCH 3/5] feat: netmhciipan works, fixed duplication bug --- .../app/api/api_v1/endpoints/prediction.py | 3 + .../src/app/schemas/mhc_ii_prediction.py | 6 +- .../src/app/services/background_tasks.py | 29 +++- apps/fastapi/src/app/services/inference.py | 38 ++++- apps/fastapi/src/app/services/postprocess.py | 134 ++++++++++++++++-- .../forms/conformational-b-structure-form.tsx | 12 +- .../peptides/forms/linear-b-form.tsx | 8 +- .../components/peptides/forms/mhc-i-form.tsx | 36 ++--- .../components/peptides/forms/mhc-ii-form.tsx | 40 +++--- .../tables/mhc-peptide-dialog-cell.tsx | 2 +- .../src/app/services/inference.py | 98 ++++++------- 11 files changed, 280 insertions(+), 126 deletions(-) diff --git a/apps/fastapi/src/app/api/api_v1/endpoints/prediction.py b/apps/fastapi/src/app/api/api_v1/endpoints/prediction.py index e97263d..02d53cc 100644 --- a/apps/fastapi/src/app/api/api_v1/endpoints/prediction.py +++ b/apps/fastapi/src/app/api/api_v1/endpoints/prediction.py @@ -44,6 +44,7 @@ async def create_conformational_b_prediction( chain=prediction_in.chain, is_structure_based=prediction_in.is_structure_based, prediction_type="conformational-b", + user_id=user.id, db=db, ) @@ -71,6 +72,7 @@ async def create_linear_b_prediction( job_id=job.id, sequence=prediction_in.sequence, prediction_type="linear-b", + user_id=user.id, db=db, ) @@ -128,6 +130,7 @@ async def create_mhc_ii_prediction( sequence=prediction_in.sequence, alleles=prediction_in.alleles, prediction_type="mhc-ii", + user_id=user.id, db=db, ) diff --git a/apps/fastapi/src/app/schemas/mhc_ii_prediction.py b/apps/fastapi/src/app/schemas/mhc_ii_prediction.py index 8ad0155..b05efaf 100644 --- a/apps/fastapi/src/app/schemas/mhc_ii_prediction.py +++ b/apps/fastapi/src/app/schemas/mhc_ii_prediction.py @@ -8,9 +8,9 @@ # Define the structure of the CSV data for MHC-II as a Pydantic model class MhcIIPredictionResult(BaseModel): Peptide_Sequence: str - ClassI_TCR_Recognition: Optional[float] = Field(default=None) - ClassI_MHC_Binding_Affinity: Optional[str] = Field(default="") - ClassI_pMHC_Stability: Optional[str] = Field(default="") + 
+    ClassII_MHC_Binding_Affinity: Optional[str] = Field(default="")
+    ClassII_pMHC_Stability: Optional[str] = Field(default="")
     Best_Binding_Affinity: Optional[str] = Field(default="")
     Best_pMHC_Stability: Optional[str] = Field(default="")
 
diff --git a/apps/fastapi/src/app/services/background_tasks.py b/apps/fastapi/src/app/services/background_tasks.py
index 826d5f6..bfaaae0 100644
--- a/apps/fastapi/src/app/services/background_tasks.py
+++ b/apps/fastapi/src/app/services/background_tasks.py
@@ -7,14 +7,17 @@
 from app.crud.crud_conformational_b_prediction import crud_conformational_b_prediction
 from app.crud.crud_job import crud_job
 from app.crud.crud_linear_b_prediction import crud_linear_b_prediction
-from app.crud.crud_mhc_ii_prediction import crud_mhc_ii_prediction
-from app.services.inference import run_netmhci_binding_affinity_classI
+from app.services.inference import (
+    run_netmhci_binding_affinity_classI,
+    run_netmhcii_binding_affinity_classII,
+)
 from app.services.postprocess import (
     postprocess_mhc_i_prediction,
+    postprocess_mhc_ii_prediction,
     process_classI_results,
+    process_classII_results,
     process_conformational_b_prediction,
     process_linear_b_prediction,
-    process_mhc_ii_prediction,
 )
 from app.services.preprocess import preprocess_protein_sequence
 
@@ -90,12 +93,24 @@ async def process_and_update_prediction(
             prediction_type=prediction_type,
         )
     elif prediction_type == "mhc-ii":
-        # Step 1: Split protein sequence into peptides
+        # Step 1: Split protein sequence into peptides (preprocessing)
         peptides = preprocess_protein_sequence(sequence, prediction_type)
 
-        results = await process_mhc_ii_prediction(sequence=sequence)
-        await crud_mhc_ii_prediction.update_result(
-            db=db, job_id=job_id, result=results
+        # Step 2: Run NetMHCIIpan-4.3 binding affinity predictions (inference)
+        netmhcii_results = await run_netmhcii_binding_affinity_classII(
+            peptides, alleles
+        )
+
+        # Step 3: Process NetMHCIIpan-4.3 results (postprocessing)
+        processed_results = await process_classII_results(netmhcii_results)
+
+        # Step 4: Process the results (postprocessing)
+        await postprocess_mhc_ii_prediction(
+            db=db,
+            job_id=job_id,
+            results=processed_results,  # Passing results from inference
+            user_id=user_id,
+            prediction_type=prediction_type,
         )
     else:
         raise HTTPException(status_code=400, detail="Unsupported prediction type.")
diff --git a/apps/fastapi/src/app/services/inference.py b/apps/fastapi/src/app/services/inference.py
index c092671..aa6df51 100644
--- a/apps/fastapi/src/app/services/inference.py
+++ b/apps/fastapi/src/app/services/inference.py
@@ -42,7 +42,9 @@ def get_sagemaker_predictions(requests, endpoint_name, model_name, model_type="m
     return responses
 
 
-timeout = httpx.Timeout(10.0, read=60.0)  # 10 seconds connect, 60 seconds read timeout
+timeout = httpx.Timeout(
+    10.0, read=3000.0
+)  # 10 seconds connect, 50 minutes read timeout
 
 
 async def run_netmhci_binding_affinity_classI(
@@ -78,3 +80,37 @@
     except Exception as e:
         logger.error(f"Unexpected error: {e}")
         return {"error": str(e)}
+
+
+async def run_netmhcii_binding_affinity_classII(
+    peptides: List[str], alleles: List[str]
+) -> List[Dict[str, Any]]:
+    """
+    Calls the NetMHCIIpan API with peptides and alleles to get class II binding affinity results.
+
+    Args:
+        peptides (List[str]): List of peptide sequences.
+        alleles (List[str]): List of HLA alleles for predictions.
+
+    Returns:
+        List[Dict[str, Any]]: List of prediction results.
+    """
+    payload = {"peptides": peptides, "alleles": alleles}
+
+    try:
+        async with httpx.AsyncClient(timeout=timeout) as client:
+            response = await client.post(CLASSII_URL, json=payload)
+            response.raise_for_status()
+            results = response.json()
+            return results
+    except httpx.RequestError as e:
+        logger.error(f"Request error: {e}")
+        return [{"peptides": peptides, "error": str(e)}]
+    except httpx.HTTPStatusError as e:
+        logger.error(
+            f"HTTP error: {e.response.status_code}, details: {e.response.text}"
+        )
+        return {"error": e.response.text}
+    except Exception as e:
+        logger.error(f"Unexpected error: {e}")
+        return {"error": str(e)}
diff --git a/apps/fastapi/src/app/services/postprocess.py b/apps/fastapi/src/app/services/postprocess.py
index 40c110b..895977d 100644
--- a/apps/fastapi/src/app/services/postprocess.py
+++ b/apps/fastapi/src/app/services/postprocess.py
@@ -10,6 +10,7 @@
 from app.core.config import settings
 from app.core.utils import generate_csv_key, read_s3_csv, upload_csv_to_s3
 from app.crud.crud_mhc_i_prediction import crud_mhc_i_prediction
+from app.crud.crud_mhc_ii_prediction import crud_mhc_ii_prediction
 from app.schemas.conformational_b_prediction import PredictionResult
 from app.schemas.linear_b_prediction import LBPredictionResult
 from app.schemas.mhc_i_prediction import MhcIPredictionResult
@@ -209,19 +210,132 @@ async def postprocess_mhc_i_prediction(
     )


-async def process_mhc_ii_prediction(sequence: str) -> List[MhcIIPredictionResult]:
+async def postprocess_mhc_ii_prediction(
+    db: AsyncClient,
+    job_id: str,
+    results: List[MhcIIPredictionResult],  # Results from inference
+    user_id: str,
+    prediction_type: str,
+):
     """
-    Process an MHC-II prediction by reading a CSV file from S3 and validating results.
+    Consolidated postprocessing function for MHC-II predictions.
+    - Processes the prediction results
+    - Uploads results as a CSV to S3
+    - Updates the database with the results and CSV download URL
     """
-    csv_filename = "class_II.csv"
-    s3_key = f"data/{csv_filename}"  # The S3 key for the file
+    # Step 1: Process the results
+    processed_results = results

-    # Use the utility function to read the CSV and validate rows
-    results = read_s3_csv(settings.S3_BUCKET_NAME, s3_key, MhcIIPredictionResult)
+    # Step 2: Generate a unique CSV filename and upload to S3
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    s3_key = generate_csv_key(
+        user_id=user_id,
+        job_id=job_id,
+        timestamp=timestamp,
+        prediction_type=prediction_type,
+    )
+    await upload_csv_to_s3(processed_results, s3_key)

-    if not results:
-        raise HTTPException(
-            status_code=404, detail=f"CSV file not found in S3 for sequence {sequence}."
+    # Step 3: Update the database with results and CSV URL
+    await crud_mhc_ii_prediction.update_result(
+        db=db,
+        job_id=job_id,
+        result=processed_results,
+        csv_download_url=f"https://{settings.S3_BUCKET_NAME}.s3.amazonaws.com/{s3_key}",
+    )
+
+
+async def process_classII_results(
+    results: List[Dict[str, Any]],
+) -> List[MhcIIPredictionResult]:
+    """
+    Processes the Class II results returned by the IEDB API or similar prediction tool.
+    - Extracts relevant peptide, allele, and affinity data.
+    - Formats the results and calculates the best binding affinity.
+ """ + peptide_data = {} + + # Check if there's an error in the results + if isinstance(results, dict) and "error" in results: + logger.error(f"Error in results: {results['error']}") + raise HTTPException(status_code=400, detail=results["error"]) + + for res in results: + if "result" in res: + # Check if the result is a list of dictionaries instead of a string + if isinstance(res["result"], list): + # Handle the case where the result is a list of dictionaries + df = pd.DataFrame( + res["result"] + ) # Directly convert the list of dicts to a DataFrame + else: + try: + # Handle the case where the result is a string + df = pd.read_csv(StringIO(res["result"]), sep="\t") + except Exception as e: + logger.error(f"Error reading data for result {res}: {str(e)}") + raise HTTPException( + status_code=500, detail="Error processing results." + ) + + # Check for required columns in the results (adjust for class II) + if {"peptide", "allele", "affinity"}.issubset(df.columns): + for _, row in df.iterrows(): + peptide = row["peptide"] + allele = row["allele"] + affinity = row["affinity"] + + peptide_data.setdefault(peptide, {"binding_affinities": []}) + peptide_data[peptide]["binding_affinities"].append( + (allele, float(affinity)) + ) + else: + logger.warning(f"Unexpected columns in API response: {df.columns}") + else: + # Handle errors if any + for peptide in res["peptides"]: + peptide_data.setdefault(peptide, {"binding_affinities": []}) + logger.error(f"Error for peptide {peptide}: {res.get('error')}") + + # Format processed results + processed_results = [] + for peptide, data in peptide_data.items(): + binding_affinities = data.get("binding_affinities", []) + + # Ensure binding_affinities is a list and each element is a tuple of (allele, affinity) + if isinstance(binding_affinities, list) and all( + isinstance(x, tuple) and len(x) == 2 for x in binding_affinities + ): + logger.debug( + f"Binding affinities for peptide {peptide}: {binding_affinities}" + ) + + binding_affinity_str = "|".join( + [ + f"{allele}={affinity:.2f} nM" + for allele, affinity in binding_affinities + ] + ) + else: + logger.warning( + f"Binding affinities for peptide {peptide} is not formatted correctly: {binding_affinities}" + ) + binding_affinity_str = "" + + # Determine the best binding affinity (minimum affinity value) + best_binding_affinity = ( + f"{min(binding_affinities, key=lambda x: x[1])}" + if binding_affinities + else "" ) - return results + # Append the formatted result for this peptide + processed_results.append( + MhcIIPredictionResult( + Peptide_Sequence=peptide, + ClassII_MHC_Binding_Affinity=binding_affinity_str, + Best_Binding_Affinity=best_binding_affinity, + ) + ) + + return processed_results diff --git a/apps/nextjs/src/components/peptides/forms/conformational-b-structure-form.tsx b/apps/nextjs/src/components/peptides/forms/conformational-b-structure-form.tsx index ab4909d..0fe386d 100644 --- a/apps/nextjs/src/components/peptides/forms/conformational-b-structure-form.tsx +++ b/apps/nextjs/src/components/peptides/forms/conformational-b-structure-form.tsx @@ -46,8 +46,8 @@ const ConformationalBForm: React.FC = () => { sequence: "", pdbId: "", chain: "", - bcrRecognitionProbabilityMethod: "", - surfaceAccessibilityMethod: "", + bcrRecognitionProbabilityMethod: "esm-2", + surfaceAccessibilityMethod: "method-1", }, }); const { watch, setValue } = form; @@ -282,9 +282,7 @@ const ConformationalBForm: React.FC = () => { - Method 1 - Method 2 - Method 3 + ESM2 @@ -309,9 +307,7 @@ const ConformationalBForm: React.FC 
= () => { - Method A - Method B - Method C + Method 1 diff --git a/apps/nextjs/src/components/peptides/forms/linear-b-form.tsx b/apps/nextjs/src/components/peptides/forms/linear-b-form.tsx index a1cc357..b21fa20 100644 --- a/apps/nextjs/src/components/peptides/forms/linear-b-form.tsx +++ b/apps/nextjs/src/components/peptides/forms/linear-b-form.tsx @@ -43,8 +43,8 @@ const LinearBForm: React.FC = () => { schema: LinearBFormSchema, defaultValues: { sequence: "", - bCellImmunogenicityMethod: undefined, - bcrRecognitionProbabilityMethod: "", + bCellImmunogenicityMethod: "method-1", + bcrRecognitionProbabilityMethod: "bepipred-2", }, }); @@ -208,9 +208,7 @@ const LinearBForm: React.FC = () => { - Method 1 - Method 2 - Method 3 + BepiPred 2.0 diff --git a/apps/nextjs/src/components/peptides/forms/mhc-i-form.tsx b/apps/nextjs/src/components/peptides/forms/mhc-i-form.tsx index 09190fd..4cae82f 100644 --- a/apps/nextjs/src/components/peptides/forms/mhc-i-form.tsx +++ b/apps/nextjs/src/components/peptides/forms/mhc-i-form.tsx @@ -46,9 +46,9 @@ const MhcIForm: React.FC = () => { defaultValues: { sequence: "", alleles: [], - tcrRecognitionProbabilityMethod: "", - mhcBindingAffinityMethod: "", - pmhcStabilityMethod: "", + tcrRecognitionProbabilityMethod: "mix-tcr-pred", + mhcBindingAffinityMethod: "netmhcpan-4.1", + pmhcStabilityMethod: "method-1", }, }); @@ -201,6 +201,18 @@ const MhcIForm: React.FC = () => { const isHlaAIndeterminate = HLA_A_ALLELES.some((allele) => field.value.includes(allele)) && !isHlaASelected; + const isHlaBSelected = HLA_B_ALLELES.every((allele) => + field.value.includes(allele), + ); + const isHlaBIndeterminate = + HLA_B_ALLELES.some((allele) => field.value.includes(allele)) && + !isHlaBSelected; + const isHlaCSelected = HLA_C_ALLELES.every((allele) => + field.value.includes(allele), + ); + const isHlaCIndeterminate = + HLA_C_ALLELES.some((allele) => field.value.includes(allele)) && + !isHlaCSelected; return ( @@ -255,9 +267,7 @@ const MhcIForm: React.FC = () => { - field.value.includes(allele), - )} + checked={isHlaBSelected || isHlaBIndeterminate} onCheckedChange={(checked) => toggleGroup( HLA_B_ALLELES, @@ -302,9 +312,7 @@ const MhcIForm: React.FC = () => { - field.value.includes(allele), - )} + checked={isHlaCSelected || isHlaCIndeterminate} onCheckedChange={(checked) => toggleGroup( HLA_C_ALLELES, @@ -368,9 +376,7 @@ const MhcIForm: React.FC = () => { - Method 1 - Method 2 - Method 3 + MixTCRpred @@ -395,9 +401,7 @@ const MhcIForm: React.FC = () => { - Method 1 - Method 2 - Method 3 + NetMHCpan 4.1 @@ -423,8 +427,6 @@ const MhcIForm: React.FC = () => { Method 1 - Method 2 - Method 3 diff --git a/apps/nextjs/src/components/peptides/forms/mhc-ii-form.tsx b/apps/nextjs/src/components/peptides/forms/mhc-ii-form.tsx index fac3074..6ec15a5 100644 --- a/apps/nextjs/src/components/peptides/forms/mhc-ii-form.tsx +++ b/apps/nextjs/src/components/peptides/forms/mhc-ii-form.tsx @@ -50,9 +50,9 @@ const MhcIIForm: React.FC = () => { defaultValues: { sequence: "", alleles: [], - tcrRecognitionProbabilityMethod: "", - mhcBindingAffinityMethod: "", - pmhcStabilityMethod: "", + tcrRecognitionProbabilityMethod: "mix-tcr-pred", + mhcBindingAffinityMethod: "netmhciipan-4.3", + pmhcStabilityMethod: "method-1", }, }); @@ -199,6 +199,20 @@ const MhcIIForm: React.FC = () => { HLA_DRB_ALLELES.some((allele) => field.value.includes(allele), ) && !isDrbSelected; + const isDqaDqbSelected = HLA_DQA_DQB_ALLELES.every((allele) => + field.value.includes(allele), + ); + const isDqaDqbIndeterminate = + 
HLA_DQA_DQB_ALLELES.some((allele) => + field.value.includes(allele), + ) && !isDqaDqbSelected; + const isDpaDpbSelected = HLA_DPA_DPB_ALLELES.every((allele) => + field.value.includes(allele), + ); + const isDpaDpbIndeterminate = + HLA_DPA_DPB_ALLELES.some((allele) => + field.value.includes(allele), + ) && !isDpaDpbSelected; return ( @@ -253,9 +267,7 @@ const MhcIIForm: React.FC = () => { - field.value.includes(allele), - )} + checked={isDqaDqbSelected || isDqaDqbIndeterminate} onCheckedChange={(checked) => toggleGroup( HLA_DQA_DQB_ALLELES, @@ -300,9 +312,7 @@ const MhcIIForm: React.FC = () => { - field.value.includes(allele), - )} + checked={isDpaDpbSelected || isDpaDpbIndeterminate} onCheckedChange={(checked) => toggleGroup( HLA_DPA_DPB_ALLELES, @@ -366,9 +376,7 @@ const MhcIIForm: React.FC = () => { - Method 1 - Method 2 - Method 3 + MixTCRpred @@ -393,9 +401,9 @@ const MhcIIForm: React.FC = () => { - Method 1 - Method 2 - Method 3 + + NetMHCIIpan 4.3 + @@ -421,8 +429,6 @@ const MhcIIForm: React.FC = () => { Method 1 - Method 2 - Method 3 diff --git a/apps/nextjs/src/components/predictions/tables/mhc-peptide-dialog-cell.tsx b/apps/nextjs/src/components/predictions/tables/mhc-peptide-dialog-cell.tsx index c808f23..f093b90 100644 --- a/apps/nextjs/src/components/predictions/tables/mhc-peptide-dialog-cell.tsx +++ b/apps/nextjs/src/components/predictions/tables/mhc-peptide-dialog-cell.tsx @@ -81,7 +81,7 @@ export function MhcPeptideDialogCell({ rowData }: PeptideDialogProps) { {Peptide_Sequence} - +
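Note on the tools-fastapi change that follows: run_binding_predictions now loops over peptides_by_length, so each predict_subsequences call handles a single peptide length across all alleles at once instead of one allele at a time. The grouping step itself is outside this hunk; the sketch below is a hypothetical stand-in (group_peptides_by_length is not a function in this repo) showing the shape of input that loop expects.

from collections import defaultdict
from typing import Dict, List


def group_peptides_by_length(peptides: List[str]) -> Dict[int, List[str]]:
    """Bucket candidate peptides by length so each NetMHCpan/NetMHCIIpan call
    can pass a single value in peptide_lengths, as the loop below assumes."""
    peptides_by_length: Dict[int, List[str]] = defaultdict(list)
    for peptide in peptides:
        peptides_by_length[len(peptide)].append(peptide)
    return dict(peptides_by_length)


# Example: an 8-mer and two 9-mers end up in two buckets keyed by 8 and 9.
print(group_peptides_by_length(["SIINFEKL", "SIINFEKLM", "GILGFVFTL"]))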
diff --git a/apps/tools-fastapi/src/app/services/inference.py b/apps/tools-fastapi/src/app/services/inference.py index 886a7d0..e061994 100644 --- a/apps/tools-fastapi/src/app/services/inference.py +++ b/apps/tools-fastapi/src/app/services/inference.py @@ -4,7 +4,7 @@ from mhcnames import normalize_allele_name # Import the correct elution score versions of the predictors -from mhctools import NetMHCIIpan43, NetMHCpan41 +from mhctools import NetMHCIIpan43_BA, NetMHCpan41 logger = logging.getLogger(__name__) @@ -24,7 +24,7 @@ async def run_binding_predictions( List[Dict[str, Any]]: List of prediction results for each allele and peptide length. """ - # Validate alleles + # Validate and normalize alleles try: alleles = [normalize_allele_name(a) for a in alleles] logger.info(f"Normalized alleles: {alleles}") @@ -49,66 +49,50 @@ async def run_binding_predictions( # Initialize predictor based on the type predictor = None if predictor_type == "netmhcpan": - # Use NetMHCpan41_EL for elution score mode predictor = NetMHCpan41(alleles=alleles) elif predictor_type == "netmhciipan": - # Use NetMHCIIpan43_EL for elution score mode - predictor = NetMHCIIpan43(alleles=alleles) + predictor = NetMHCIIpan43_BA(alleles=alleles) if not predictor: raise ValueError(f"Unknown predictor type: {predictor_type}") - for allele in alleles: - for length, peptides_subset in peptides_by_length.items(): - try: - # Ensure valid peptide length before processing - if length is None or not isinstance(length, int): - logger.error( - f"Skipping invalid length for peptides: {peptides_subset}" - ) - continue - - logger.info( - f"Predicting subsequences for peptides: {peptides_subset} and length: {length}" - ) - - binding_predictions = predictor.predict_subsequences( - {f"seq{i}": seq for i, seq in enumerate(peptides_subset)}, - peptide_lengths=[length], - ) - - if not binding_predictions: - logger.error( - f"No predictions for allele {allele} and length {length}" - ) - continue - - # Convert predictions to a DataFrame and then to a dictionary - df = binding_predictions.to_dataframe() - logger.info(f"Prediction DataFrame: {df}") - logger.info(f"Converted to dict: {df.to_dict(orient='records')}") - results.append( - { - "allele": allele, - "length": length, - "peptides": peptides_subset, - "result": df.to_dict(orient="records"), - } - ) - logger.info( - f"Successfully predicted binding affinity for allele {allele} and length {length}." 
- ) - except Exception as e: - logger.error( - f"Error processing allele {allele} and length {length}: {e}" - ) - results.append( - { - "allele": allele, - "length": length, - "peptides": peptides_subset, - "error": str(e), - } - ) + # Iterate over the grouped peptides by length and run predictions + for length, peptides_subset in peptides_by_length.items(): + try: + logger.info( + f"Predicting subsequences for peptides: {peptides_subset} and length: {length}" + ) + + # Run prediction for all peptides of this length and all alleles at once + binding_predictions = predictor.predict_subsequences( + {f"seq{i}": seq for i, seq in enumerate(peptides_subset)}, + peptide_lengths=[length], + ) + + if not binding_predictions: + logger.error(f"No predictions for length {length}") + continue + + # Convert predictions to a DataFrame and then to a dictionary + df = binding_predictions.to_dataframe() + logger.info(f"Prediction DataFrame: {df}") + logger.info(f"Converted to dict: {df.to_dict(orient='records')}") + results.append( + { + "length": length, + "peptides": peptides_subset, + "result": df.to_dict(orient="records"), + } + ) + logger.info(f"Successfully predicted binding affinity for length {length}.") + except Exception as e: + logger.error(f"Error processing length {length}: {e}") + results.append( + { + "length": length, + "peptides": peptides_subset, + "error": str(e), + } + ) return results From f0277d03762c16ba5d80c56038fc1eec8fe0c224 Mon Sep 17 00:00:00 2001 From: ElektrikSpark Date: Sun, 20 Oct 2024 03:52:36 +0000 Subject: [PATCH 4/5] feat: training sema 1d on sagemaker works --- apps/tools-fastapi/bepipred-2.0.readme | 103 +++ .../SEMA-1D_finetuning_ESMV2_650.ipynb | 0 .../scripts/inference.py | 37 + .../scripts/lora-train.py | 0 .../scripts/requirements.txt | 9 + .../SEMA-1D_finetuning_ESM2_SageMaker.ipynb | 648 ++++++++++++++++++ .../sema-1d-sagemaker/scripts/inference.py | 140 +++- .../scripts/requirements.txt | 4 +- notebooks/sema-1d-sagemaker/scripts/train.py | 316 +++++++++ 9 files changed, 1230 insertions(+), 27 deletions(-) create mode 100644 apps/tools-fastapi/bepipred-2.0.readme rename notebooks/{sema-1d-sagemaker => sema-1d-sagemaker-old}/SEMA-1D_finetuning_ESMV2_650.ipynb (100%) create mode 100644 notebooks/sema-1d-sagemaker-old/scripts/inference.py rename notebooks/{sema-1d-sagemaker => sema-1d-sagemaker-old}/scripts/lora-train.py (100%) create mode 100644 notebooks/sema-1d-sagemaker-old/scripts/requirements.txt create mode 100644 notebooks/sema-1d-sagemaker/SEMA-1D_finetuning_ESM2_SageMaker.ipynb create mode 100644 notebooks/sema-1d-sagemaker/scripts/train.py diff --git a/apps/tools-fastapi/bepipred-2.0.readme b/apps/tools-fastapi/bepipred-2.0.readme new file mode 100644 index 0000000..db4a568 --- /dev/null +++ b/apps/tools-fastapi/bepipred-2.0.readme @@ -0,0 +1,103 @@ +BepiPred-2.0 +Sequence based B-cell epitope prediction tool +========================== + +0. Prerequisites + In order to install BepiPred-2.0, you will need the NetsurfP installed + either globally, so the "netsurfp" global command is available or + an environmental variable named NETSURFP_BIN with the path to the + binary of netsurfp. + + To download and install NetsurfP, visit the download page; + http://www.cbs.dtu.dk/cgi-bin/sw_request?netsurfp+1.0 + + +1. Installation + BepiPred-2.0 is compatible with python version 2.7 and above. 
It is highly adviced to use + a python virtual environment to avoid version conflicts: + cd your_VE_folder + virtualenv bp2 + source ./bp2/bin/activate + + If source fails try the following instead; + source ./bp2/bin/activate.csh + + *** NOTE *** + It is highly recommended if netsurfp is not globally avaiable to add the full + path to netsurfp to avoid to have to set the environmental variable everytime + you enter your virtual environment. + + If activate works, edit the activate file below the line "export VIRTUAL_ENV" + to contain; + export NETSURFP_BIN="Fullpath to netsurfp binary" + + If activate.csh works, edit the activate.csh file below the + line "setenv PATH "$VIRTUAL_ENV/bin:$PATH"" to contain; + + setenv NETSURFP_BIN "Fullpath to netsurfp binary" + ************* + + You can install BepiPred-2.0 by using pip: + + pip install bepipred-2.0/ + + You can check your installation with + + BepiPred-2.0 bepipred-2.0/bepipred2/data/example.fasta + + It will install all the python required packages. To uninstall, use + pip uninstall bepipred2 + +2. Usage + After installing, BepiPred-2.0 can be used both from commandline and + be intergrated into your own personal python script. How to use either + is described below. + + 2.1 Using BepiPred-2.0 from the commandline + The package will install the executable BepiPred-2.0. This + executable allows the prediction of B-cell epitopes using only + an antigen's amino acid sequence. + + *** NOTE *** + If you have multiple antigen sequences, it is highly recommended to submit + a single fasta file containing all and not call it separately, due to a + large computational performance gain. + ************ + + The basic usage is the following: + + BepiPred-2.0 sequence.fasta + + where sequence.fasta is a fasta format file containing the + amino acid sequence of an antigen or multiple antigens. + + The additional following options are available: + + usage: BepiPred-2.0 [-h] [-t THRESHOLD] fastafile + + positional arguments: + fastafile Fasta file containing antigen(s) + + optional arguments: + -h, --help show this help message and exit + -t THRESHOLD, --threshold THRESHOLD + Threshold on when to consider residues, epitope + residues + + 2.2 Integrating BepiPred-2.0 into a python script + After installation, it is possible to import BepiPred-2.0 into a python + script. To try it, open a python terminal in the environment where BepiPred-2.0 + is available. Then write the following; + + import bepipred2 as bp2 #imports BepiPred-2.0 + bp2.utils.RF_MODEL = bp2.utils.init_rf() #Unloads the random forest predictor + + seq = 'cdafvgtwKLVssenfddymkevgvgfatrkvagMAKpnmiisvngdlvtirsesTfkn' #Amino acid sequence + id = 'Example1' #Name/identifier of sequence + AG = bp2.Antigen(id, seq) #Sets up antigen class object + AG.pred_netsurfp() #Predicts surface and secondary structure of antigen + AG.get_features() #Sets up the feature space for predicting + AG.predict() #Predicts the epitopes of the antigen + + + The predictions of the sequence can then be found in class variable, AG.predicted. 
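Following on from section 2.2 above, a minimal sketch of turning the per-residue output into epitope calls. It assumes AG.predicted holds one epitope score per residue, aligned with the input sequence; the readme does not guarantee that layout, so check it against the installed bepipred2 version before relying on it.

# Continues the section 2.2 example (seq and AG already defined).
# ASSUMPTION: AG.predicted is a flat list of per-residue epitope scores.
THRESHOLD = 0.5  # mirrors the -t/--threshold command-line option

epitope_calls = [
    (position, residue, score)
    for position, (residue, score) in enumerate(zip(seq, AG.predicted), start=1)
    if score >= THRESHOLD
]
print(epitope_calls)  # [(position, amino acid, score), ...] for putative epitope residues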
diff --git a/notebooks/sema-1d-sagemaker/SEMA-1D_finetuning_ESMV2_650.ipynb b/notebooks/sema-1d-sagemaker-old/SEMA-1D_finetuning_ESMV2_650.ipynb similarity index 100% rename from notebooks/sema-1d-sagemaker/SEMA-1D_finetuning_ESMV2_650.ipynb rename to notebooks/sema-1d-sagemaker-old/SEMA-1D_finetuning_ESMV2_650.ipynb diff --git a/notebooks/sema-1d-sagemaker-old/scripts/inference.py b/notebooks/sema-1d-sagemaker-old/scripts/inference.py new file mode 100644 index 0000000..d7ca1ff --- /dev/null +++ b/notebooks/sema-1d-sagemaker-old/scripts/inference.py @@ -0,0 +1,37 @@ +import torch +from transformers import AutoTokenizer, EsmForTokenClassification + + +# Load the model and tokenizer +def model_fn(model_dir): + # Load the EsmForTokenClassification model for regression + model = EsmForTokenClassification.from_pretrained( + model_dir, + device_map="auto", + num_labels=1, # Since it's a regression task + ) + tokenizer = AutoTokenizer.from_pretrained(model_dir) + + return model, tokenizer + + +# Prediction function +def predict_fn(data, model_and_tokenizer): + model, tokenizer = model_and_tokenizer + model.eval() + + # Prepare input data for the model + inputs = data.pop("inputs", data) + encoding = tokenizer(inputs, return_tensors="pt") + encoding = {k: v.to(model.device) for k, v in encoding.items()} + + # Run inference + with torch.no_grad(): + results = model(**encoding) + + # For regression, we directly use the logits as the predicted value + predictions = results.logits.cpu().numpy() + + return { + "predicted_contact_number": predictions[0].tolist() + } # Return prediction(s) as a list diff --git a/notebooks/sema-1d-sagemaker/scripts/lora-train.py b/notebooks/sema-1d-sagemaker-old/scripts/lora-train.py similarity index 100% rename from notebooks/sema-1d-sagemaker/scripts/lora-train.py rename to notebooks/sema-1d-sagemaker-old/scripts/lora-train.py diff --git a/notebooks/sema-1d-sagemaker-old/scripts/requirements.txt b/notebooks/sema-1d-sagemaker-old/scripts/requirements.txt new file mode 100644 index 0000000..6a1d080 --- /dev/null +++ b/notebooks/sema-1d-sagemaker-old/scripts/requirements.txt @@ -0,0 +1,9 @@ +accelerate +bitsandbytes +datasets +evaluate +nvidia-ml-py3 +peft +scikit-learn +transformers +torchinfo diff --git a/notebooks/sema-1d-sagemaker/SEMA-1D_finetuning_ESM2_SageMaker.ipynb b/notebooks/sema-1d-sagemaker/SEMA-1D_finetuning_ESM2_SageMaker.ipynb new file mode 100644 index 0000000..e8c3817 --- /dev/null +++ b/notebooks/sema-1d-sagemaker/SEMA-1D_finetuning_ESM2_SageMaker.ipynb @@ -0,0 +1,648 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SEMA-1D " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "SEMA-1D is a fine-tuned ESM-1v model aimed to predict epitope resiudes based on antigen protein sequence" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 
Set up Environment" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: datasets in /opt/conda/lib/python3.11/site-packages (2.21.0)\n", + "Requirement already satisfied: huggingface-hub in /opt/conda/lib/python3.11/site-packages (0.24.5)\n", + "Requirement already satisfied: s3fs==0.4.2 in /opt/conda/lib/python3.11/site-packages (0.4.2)\n", + "Requirement already satisfied: fair-esm in /opt/conda/lib/python3.11/site-packages (2.0.0)\n", + "Requirement already satisfied: botocore>=1.12.91 in /opt/conda/lib/python3.11/site-packages (from s3fs==0.4.2) (1.34.131)\n", + "Requirement already satisfied: fsspec>=0.6.0 in /opt/conda/lib/python3.11/site-packages (from s3fs==0.4.2) (2023.6.0)\n", + "Requirement already satisfied: filelock in /opt/conda/lib/python3.11/site-packages (from datasets) (3.15.4)\n", + "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.11/site-packages (from datasets) (1.26.4)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in /opt/conda/lib/python3.11/site-packages (from datasets) (15.0.2)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /opt/conda/lib/python3.11/site-packages (from datasets) (0.3.8)\n", + "Requirement already satisfied: pandas in /opt/conda/lib/python3.11/site-packages (from datasets) (2.2.2)\n", + "Requirement already satisfied: requests>=2.32.2 in /opt/conda/lib/python3.11/site-packages (from datasets) (2.32.3)\n", + "Requirement already satisfied: tqdm>=4.66.3 in /opt/conda/lib/python3.11/site-packages (from datasets) (4.66.5)\n", + "Requirement already satisfied: xxhash in /opt/conda/lib/python3.11/site-packages (from datasets) (3.5.0)\n", + "Requirement already satisfied: multiprocess in /opt/conda/lib/python3.11/site-packages (from datasets) (0.70.16)\n", + "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.11/site-packages (from datasets) (3.9.5)\n", + "Requirement already satisfied: packaging in /opt/conda/lib/python3.11/site-packages (from datasets) (24.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.11/site-packages (from datasets) (6.0.2)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.11/site-packages (from huggingface-hub) (4.12.2)\n", + "Requirement already satisfied: jmespath<2.0.0,>=0.7.1 in /opt/conda/lib/python3.11/site-packages (from botocore>=1.12.91->s3fs==0.4.2) (1.0.1)\n", + "Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /opt/conda/lib/python3.11/site-packages (from botocore>=1.12.91->s3fs==0.4.2) (2.9.0)\n", + "Requirement already satisfied: urllib3!=2.2.0,<3,>=1.25.4 in /opt/conda/lib/python3.11/site-packages (from botocore>=1.12.91->s3fs==0.4.2) (1.26.19)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets) (1.9.4)\n", + "Requirement already satisfied: 
charset-normalizer<4,>=2 in /opt/conda/lib/python3.11/site-packages (from requests>=2.32.2->datasets) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.11/site-packages (from requests>=2.32.2->datasets) (3.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.11/site-packages (from requests>=2.32.2->datasets) (2024.7.4)\n", + "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.11/site-packages (from pandas->datasets) (2023.3)\n", + "Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.11/site-packages (from pandas->datasets) (2024.1)\n", + "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.11/site-packages (from python-dateutil<3.0.0,>=2.1->botocore>=1.12.91->s3fs==0.4.2) (1.16.0)\n", + "Note: you may need to restart the kernel to use updated packages.\n", + "\u001b[33mWARNING: Skipping tensorflow as it is not installed.\u001b[0m\u001b[33m\n", + "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install datasets huggingface-hub s3fs=='0.4.2' fair-esm\n", + "%pip uninstall tensorflow -y" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Assumed SageMaker role is arn:aws:iam::340752820161:role/service-role/AmazonSageMaker-ExecutionRole-20241011T160996\n", + "S3 path is s3://sagemaker-us-east-1-340752820161/esm2-sema-1d\n", + "Experiment name is esm2-sema-1d-2024-10-20-01-21-05\n" + ] + } + ], + "source": [ + "import boto3\n", + "import json\n", + "import os\n", + "import pandas as pd\n", + "import random\n", + "import sagemaker\n", + "from sagemaker.experiments.run import Run\n", + "from sagemaker.huggingface import HuggingFace, HuggingFaceModel\n", + "from sagemaker.inputs import TrainingInput\n", + "from time import strftime\n", + "\n", + "boto_session = boto3.session.Session()\n", + "sagemaker_session = sagemaker.session.Session(boto_session)\n", + "S3_BUCKET = sagemaker_session.default_bucket()\n", + "s3 = boto_session.client(\"s3\")\n", + "sagemaker_client = boto_session.client(\"sagemaker\")\n", + "REGION_NAME = sagemaker_session.boto_region_name\n", + "\n", + "try:\n", + " sagemaker_execution_role = sagemaker_session.get_execution_role()\n", + "except AttributeError:\n", + " NOTEBOOK_METADATA_FILE = \"/opt/ml/metadata/resource-metadata.json\"\n", + " with open(NOTEBOOK_METADATA_FILE, \"rb\") as f:\n", + " metadata = json.loads(f.read())\n", + " instance_name = metadata[\"ResourceName\"]\n", + " domain_id = metadata.get(\"DomainId\")\n", + " user_profile_name = metadata.get(\"UserProfileName\")\n", + " space_name = metadata.get(\"SpaceName\")\n", + " domain_desc = sagemaker_session.sagemaker_client.describe_domain(DomainId=domain_id)\n", + " if \"DefaultSpaceSettings\" in domain_desc:\n", + " sagemaker_execution_role = domain_desc[\"DefaultSpaceSettings\"][\"ExecutionRole\"]\n", + " else:\n", + " sagemaker_execution_role = domain_desc[\"DefaultUserSettings\"][\"ExecutionRole\"]\n", + "\n", + "print(f\"Assumed SageMaker role is {sagemaker_execution_role}\")\n", + "\n", + "S3_PREFIX = \"esm2-sema-1d\"\n", + "S3_PATH = sagemaker.s3.s3_path_join(\"s3://\", S3_BUCKET, S3_PREFIX)\n", + "print(f\"S3 path is {S3_PATH}\")\n", + "\n", + "EXPERIMENT_NAME = \"esm2-sema-1d-\" + strftime(\"%Y-%m-%d-%H-%M-%S\")\n", + "print(f\"Experiment name is {EXPERIMENT_NAME}\")" + ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "## 2. Build Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "# SEMA data URLs\n", + "train_data_url = 'https://raw.githubusercontent.com/AIRI-Institute/SEMAi/main/epitopes_prediction/data/sema_2.0/train_set.csv'\n", + "test_data_url = 'https://raw.githubusercontent.com/AIRI-Institute/SEMAi/main/epitopes_prediction/data/sema_2.0/test_set.csv'\n", + "\n", + "# Download the data locally\n", + "train_df = pd.read_csv(train_data_url)\n", + "test_df = pd.read_csv(test_data_url)\n", + "\n", + "# Save to local paths\n", + "train_local_path = 'train_set.csv'\n", + "test_local_path = 'test_set.csv'\n", + "train_df.to_csv(train_local_path, index=False)\n", + "test_df.to_csv(test_local_path, index=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we upload the processed training, test, and validation data to S3." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'s3://sagemaker-us-east-1-340752820161/esm2-sema-1d/data/test/test_set.csv'" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Upload to S3\n", + "sagemaker_session.upload_data(path=train_local_path, bucket=S3_BUCKET, key_prefix=f\"{S3_PREFIX}/data/train\")\n", + "sagemaker_session.upload_data(path=test_local_path, bucket=S3_BUCKET, key_prefix=f\"{S3_PREFIX}/data/test\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Train Model in SageMaker" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "hyperparameters = {\n", + " \"epochs\": 1,\n", + " \"per_device_train_batch_size\": 1,\n", + " \"per_device_eval_batch_size\": 1,\n", + " \"learning_rate\": 1e-5,\n", + " \"warmup_steps\": 0,\n", + " \"weight_decay\": 0.0,\n", + " \"gradient_accumulation_steps\": 1,\n", + " \"seed\": 42,\n", + "}\n", + "\n", + "metric_definitions = [\n", + " {\"Name\": \"epoch\", \"Regex\": \"'epoch': ([0-9.]*)\"},\n", + " {\n", + " \"Name\": \"max_gpu_mem\",\n", + " \"Regex\": \"Max GPU memory use during training: ([0-9.e-]*) MB\",\n", + " },\n", + " {\"Name\": \"train_loss\", \"Regex\": \"'loss': ([0-9.e-]*)\"},\n", + " {\n", + " \"Name\": \"train_samples_per_second\",\n", + " \"Regex\": \"'train_samples_per_second': ([0-9.e-]*)\",\n", + " },\n", + " {\"Name\": \"eval_loss\", \"Regex\": \"'eval_loss': ([0-9.e-]*)\"},\n", + " {\"Name\": \"eval_accuracy\", \"Regex\": \"'eval_accuracy': ([0-9.e-]*)\"},\n", + "]\n", + "\n", + "# Define the HuggingFace Estimator\n", + "hf_estimator = HuggingFace(\n", + " base_job_name=\"esm-2-sema-1d\",\n", + " entry_point='train.py',\n", + " source_dir='scripts',\n", + " instance_type='ml.p3.2xlarge',\n", + " instance_count=1,\n", + " transformers_version=\"4.28\",\n", + " pytorch_version=\"2.0\",\n", + " py_version=\"py310\",\n", + " output_path=f\"{S3_PATH}/output\",\n", + " role=sagemaker_execution_role,\n", + " hyperparameters=hyperparameters,\n", + " metric_definitions=metric_definitions,\n", + " checkpoint_local_path=\"/opt/ml/checkpoints\",\n", + " sagemaker_session=sagemaker_session,\n", + " tags=[{\"Key\": \"project\", \"Value\": \"esm-fine-tuning\"}],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.\n", + "INFO:sagemaker:Creating training-job with name: esm-2-sema-1d-2024-10-20-01-21-29-420\n" + ] + } + ], + "source": [ + "train_s3_uri = S3_PATH + \"/data/train/train_set.csv\"\n", + "test_s3_uri = S3_PATH + \"/data/test/test_set.csv\"\n", + "\n", + "with Run(\n", + " experiment_name=EXPERIMENT_NAME,\n", + " sagemaker_session=sagemaker_session,\n", + ") as run:\n", + " hf_estimator.fit(\n", + " {\n", + " 'train': TrainingInput(s3_data=train_s3_uri, content_type='text/csv'),\n", + " 'test': TrainingInput(s3_data=test_s3_uri, content_type='text/csv')\n", + " },\n", + " wait=False,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can view metrics and debugging information for this run in SageMaker Experiments." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training job name: esm-2-sema-1d-2024-10-20-01-21-29-420\n", + "Training job status: InProgress\n", + "Training job output: None\n" + ] + }, + { + "data": { + "text/html": [ + "
[HTML rendering of the trial-component summary table omitted; the text/plain output below contains the same values]
" + ], + "text/plain": [ + " 0\n", + "TrialComponentName esm-2-sema-1d-2024-10-20-01-21-29-420-aws-trai...\n", + "DisplayName esm-2-sema-1d-2024-10-20-01-21-29-420-aws-trai...\n", + "SourceArn arn:aws:sagemaker:us-east-1:340752820161:train...\n", + "SageMaker.ImageUri 763104351884.dkr.ecr.us-east-1.amazonaws.com/h...\n", + "SageMaker.InstanceCount 1.0\n", + "SageMaker.InstanceType ml.p3.2xlarge\n", + "SageMaker.VolumeSizeInGB 30.0\n", + "epochs 1.0\n", + "gradient_accumulation_steps 1.0\n", + "learning_rate 0.00001\n", + "per_device_eval_batch_size 1.0\n", + "per_device_train_batch_size 1.0\n", + "sagemaker_container_log_level 20.0\n", + "sagemaker_job_name \"esm-2-sema-1d-2024-10-20-01-21-29-420\"\n", + "sagemaker_program \"train.py\"\n", + "sagemaker_region \"us-east-1\"\n", + "sagemaker_submit_directory \"s3://sagemaker-us-east-1-340752820161/esm-2-s...\n", + "seed 42.0\n", + "warmup_steps 0.0\n", + "weight_decay 0.0\n", + "test - MediaType text/csv\n", + "test - Value s3://sagemaker-us-east-1-340752820161/esm2-sem...\n", + "train - MediaType text/csv\n", + "train - Value s3://sagemaker-us-east-1-340752820161/esm2-sem...\n", + "SageMaker.DebugHookOutput - MediaType None\n", + "SageMaker.DebugHookOutput - Value s3://sagemaker-us-east-1-340752820161/esm2-sem...\n", + "Trials [Default-Run-Group-esm2-sema-1d-2024-10-20-01-...\n", + "Experiments [esm2-sema-1d-2024-10-20-01-21-05]" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sagemaker.analytics import ExperimentAnalytics\n", + "\n", + "training_job_details = hf_estimator.latest_training_job.describe()\n", + "print(f\"Training job name: {training_job_details.get('TrainingJobName')}\")\n", + "print(f\"Training job status: {training_job_details.get('TrainingJobStatus')}\")\n", + "print(f\"Training job output: {training_job_details.get('ModelArtifacts')}\")\n", + "\n", + "search_expression = {\n", + " \"Filters\": [\n", + " {\n", + " \"Name\": \"DisplayName\",\n", + " \"Operator\": \"Contains\",\n", + " \"Value\": \"Training\",\n", + " }\n", + " ],\n", + "}\n", + "\n", + "trial_component_analytics = ExperimentAnalytics(\n", + " sagemaker_session=sagemaker_session,\n", + " experiment_name=EXPERIMENT_NAME,\n", + " search_expression=search_expression,\n", + ")\n", + "\n", + "trial_component_analytics.dataframe().T" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Deploy Model as Real-Time Inference Endpoint" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To deploy our endpoint, we call deploy() on our HuggingFace estimator object, passing in our desired number of instances and instance type." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'estimator' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m predictor \u001b[38;5;241m=\u001b[39m \u001b[43mestimator\u001b[49m\u001b[38;5;241m.\u001b[39mdeploy(initial_instance_count\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, instance_type\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mml.r5.2xlarge\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'estimator' is not defined" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "predictor = hf_estimator.deploy(initial_instance_count=1, instance_type=\"ml.r5.2xlarge\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Try running some known epitopes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Example sequence input for conformational B-cell epitope prediction (Ara h 2)\n", + "sample_sequence = {\n", + " \"sequence\": \"MAKLTILVALALFLLAAHASARQQWELQGDRRCQSQLERANLRPCEQHLMQKIQRDEDSYERDPYSPSQDPYSPSPYDRRGAGSSQHQERCCNELNEFENNQRCMCEALQQIMENQSDRLQGRQQEQQFKRELRNLPQQCGLRAPQRCDLDVESGG\"\n", + "}\n", + "\n", + "# Send the sequence to the deployed SageMaker predictor for epitope prediction\n", + "response = predictor.predict(sample_sequence)\n", + "\n", + "# Print the predicted conformational B-cell epitopes\n", + "print(response)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Epoch\tTraining Loss\tValidation Loss\tPearson R\tMse\tR2 Score\n", + "#1\t0.212700\t0.150756\t0.251578\t0.173891\t-0.567424\n", + "#2\t0.157400\t0.165494\t0.253576\t0.183997\t-0.658516\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Clean up" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Delete endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " predictor.delete_endpoint()\n", + "except:\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Delete S3 data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bucket = boto_session.resource(\"s3\").Bucket(S3_BUCKET)\n", + "bucket.objects.filter(Prefix=S3_PREFIX).delete()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/sema-1d-sagemaker/scripts/inference.py b/notebooks/sema-1d-sagemaker/scripts/inference.py index d7ca1ff..ebc1072 100644 --- a/notebooks/sema-1d-sagemaker/scripts/inference.py +++ b/notebooks/sema-1d-sagemaker/scripts/inference.py @@ -1,37 +1,129 @@ +import os + +import esm import torch -from transformers import AutoTokenizer, EsmForTokenClassification +from torch import nn +from transformers.modeling_outputs import SequenceClassifierOutput -# Load the model and tokenizer def model_fn(model_dir): - # Load the EsmForTokenClassification model for regression - model = EsmForTokenClassification.from_pretrained( - model_dir, - device_map="auto", - num_labels=1, # Since it's a regression task - ) - tokenizer = AutoTokenizer.from_pretrained(model_dir) + """ + Load the model and ESM batch converter from the model directory. + + Args: + model_dir (str): Directory where the model artifacts are saved. + + Returns: + Tuple[torch.nn.Module, esm.pretrained.Alphabet]: The loaded model and ESM batch converter. 
+ """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Define the custom model class + class ESM1vForTokenClassification(nn.Module): + def __init__(self, num_labels=2, pretrained_no=1): + super().__init__() + self.num_labels = num_labels + self.model_name = "esm1v_t33_650M_UR90S_" + str(pretrained_no) + + # Load the pretrained ESM-1v model and alphabet + self.esm1v, self.esm1v_alphabet = esm.pretrained.esm1v_t33_650M_UR90S_1() + self.classifier = nn.Linear(1280, self.num_labels) - return model, tokenizer + def forward(self, token_ids, labels=None): + outputs = self.esm1v.forward(token_ids, repr_layers=[33])[ + "representations" + ][33] + outputs = outputs[:, 1:-1, :] # Remove start and end tokens + logits = self.classifier(outputs) + return SequenceClassifierOutput(logits=logits) + # Initialize and load the model + model = ESM1vForTokenClassification().to(device) -# Prediction function -def predict_fn(data, model_and_tokenizer): - model, tokenizer = model_and_tokenizer + # Load the state_dict from 'model.pth' + model.load_state_dict( + torch.load(os.path.join(model_dir, "model.pth"), map_location=device) + ) model.eval() - # Prepare input data for the model - inputs = data.pop("inputs", data) - encoding = tokenizer(inputs, return_tensors="pt") - encoding = {k: v.to(model.device) for k, v in encoding.items()} + # Initialize the ESM batch converter + _, esm1v_alphabet = esm.pretrained.esm1v_t33_650M_UR90S_1() + batch_converter = esm1v_alphabet.get_batch_converter() + + return model, batch_converter + + +def input_fn(request_body, request_content_type): + """ + Parse the incoming request body. + + Args: + request_body (str): The body of the request. + request_content_type (str): The content type of the request. + + Returns: + str: The protein sequence. + """ + import json + + if request_content_type == "application/json": + data = json.loads(request_body) + if "sequence" in data: + return data["sequence"] + else: + raise ValueError("JSON input must contain 'sequence' key.") + elif request_content_type == "text/plain": + return request_body + else: + raise ValueError(f"Unsupported content type: {request_content_type}") + + +def predict_fn(input_data, model_and_batch_converter): + """ + Perform prediction using the loaded model and batch converter. + + Args: + input_data (str): The protein sequence. + model_and_batch_converter (Tuple[torch.nn.Module, esm.pretrained.Alphabet]): The loaded model and batch converter. + + Returns: + List[float]: Per-residue epitope probabilities. + """ + model, batch_converter = model_and_batch_converter + device = next(model.parameters()).device + + # Prepare the batch + batch = batch_converter([("", input_data)]) + token_ids = batch[2].to(device) - # Run inference with torch.no_grad(): - results = model(**encoding) + outputs = model(token_ids) + logits = outputs.logits + probs = torch.sigmoid(logits) # Apply sigmoid for binary classification + probs = probs.cpu().numpy() + + # Extract per-residue epitope probabilities (class 1 in a binary classification) + epitope_probs = probs[0, :, 1].tolist() + + return epitope_probs + + +def output_fn(prediction, response_content_type): + """ + Format the prediction output. + + Args: + prediction (List[float]): The prediction probabilities. + response_content_type (str): The desired content type of the response. - # For regression, we directly use the logits as the predicted value - predictions = results.logits.cpu().numpy() + Returns: + Tuple[str, str]: The response body and its content type. 
+ """ + import json - return { - "predicted_contact_number": predictions[0].tolist() - } # Return prediction(s) as a list + if response_content_type == "application/json": + return json.dumps({"epitope_probabilities": prediction}), response_content_type + elif response_content_type == "text/plain": + return str(prediction), response_content_type + else: + raise ValueError(f"Unsupported response content type: {response_content_type}") diff --git a/notebooks/sema-1d-sagemaker/scripts/requirements.txt b/notebooks/sema-1d-sagemaker/scripts/requirements.txt index 6a1d080..16253e8 100644 --- a/notebooks/sema-1d-sagemaker/scripts/requirements.txt +++ b/notebooks/sema-1d-sagemaker/scripts/requirements.txt @@ -1,9 +1,7 @@ accelerate -bitsandbytes -datasets +fair-esm evaluate nvidia-ml-py3 -peft scikit-learn transformers torchinfo diff --git a/notebooks/sema-1d-sagemaker/scripts/train.py b/notebooks/sema-1d-sagemaker/scripts/train.py new file mode 100644 index 0000000..741a1c1 --- /dev/null +++ b/notebooks/sema-1d-sagemaker/scripts/train.py @@ -0,0 +1,316 @@ +import argparse +import math +import os +import shutil +import tempfile + +import esm +import pandas as pd +import pynvml +import scipy +import torch +from sklearn.metrics import mean_squared_error, r2_score +from torch import nn +from torch.utils.data import Dataset +from transformers import ( + EvalPrediction, + Trainer, + TrainingArguments, + set_seed, +) +from transformers.modeling_outputs import SequenceClassifierOutput + + +# Define your Dataset class +class PDB_Dataset(Dataset): + """ + A class to represent a suitable dataset for the model. + + Converts original pandas dataframe to model set, + where 'token_ids' are ESM-1v embeddings corresponding to protein sequence (max length 1022 AA) + and 'labels' are contact number values. + """ + + def __init__(self, df, label_type="regression"): + """ + Initialize the dataset. + + Parameters: + df (pandas.DataFrame): DataFrame with two columns: + 0 -- protein sequence in string ('GLVM') or list format + 1 -- contact number values in list format [0, 0.123, ...] 
+ label_type (str): Type of model: 'regression' or 'binary' + """ + self.df = df + _, self.esm1v_alphabet = esm.pretrained.esm1v_t33_650M_UR90S_1() + self.esm1v_batch_converter = self.esm1v_alphabet.get_batch_converter() + self.label_type = label_type + + def __getitem__(self, idx): + item = {} + sequence = "".join(self.df.iloc[idx, 0])[:1022] + _, _, esm1v_batch_tokens = self.esm1v_batch_converter([("", sequence)]) + item["token_ids"] = esm1v_batch_tokens + labels = self.df.iloc[idx, 1][:1022] + # Handle label transformation if necessary + labels = [math.log(t + 1) if t != -100 else -100 for t in labels] + item["labels"] = torch.unsqueeze(torch.FloatTensor(labels), 0).to(torch.float64) + return item + + def __len__(self): + return len(self.df) + + +# Define your Model +class ESM1vForTokenClassification(nn.Module): + def __init__(self, num_labels=2, pretrained_no=1): + super().__init__() + self.num_labels = num_labels + # Load the pretrained ESM model + self.esm1v, self.esm1v_alphabet = esm.pretrained.esm1v_t33_650M_UR90S_1() + self.classifier = nn.Linear(1280, self.num_labels) + + def forward(self, token_ids, labels=None): + outputs = self.esm1v.forward(token_ids, repr_layers=[33])["representations"][33] + outputs = outputs[:, 1:-1, :] # Remove start and end tokens + logits = self.classifier(outputs) + return SequenceClassifierOutput(logits=logits) + + +# Custom MSE Loss with Masking +class MaskedMSELoss(nn.Module): + def __init__(self): + super(MaskedMSELoss, self).__init__() + + def forward(self, inputs, target, mask): + diff2 = ( + torch.flatten(inputs[:, :, 1]) - torch.flatten(target) + ) ** 2.0 * torch.flatten(mask) + result = torch.sum(diff2) / torch.sum(mask) + if torch.sum(mask) == 0: + return torch.sum(diff2) + else: + return result + + +# Custom Trainer with MaskedMSELoss +class MaskedRegressTrainer(Trainer): + def compute_loss(self, model, inputs, return_outputs=False): + labels = inputs.pop("labels") + labels = labels.squeeze().detach().cpu().numpy().tolist() + labels = [math.log(t + 1) if t != -100 else -100 for t in labels] + labels = torch.unsqueeze(torch.FloatTensor(labels), 0).cuda() + masks = ~torch.eq(labels, -100).cuda() + + outputs = model(**inputs) + logits = outputs.logits + + loss_fn = MaskedMSELoss() + loss = loss_fn(logits, labels, masks) + + return (loss, outputs) if return_outputs else loss + + +# Custom Metrics +def compute_metrics_regr(p: EvalPrediction): + preds = p.predictions[:, :, 1] + batch_size, seq_len = preds.shape + out_labels, out_preds = [], [] + + for i in range(batch_size): + for j in range(seq_len): + if p.label_ids[i, j] > -1: + out_labels.append(p.label_ids[i, j]) + out_preds.append(preds[i, j]) + + out_labels_regr = out_labels + + return { + "pearson_r": scipy.stats.pearsonr(out_labels_regr, out_preds)[0], + "mse": mean_squared_error(out_labels_regr, out_preds), + "r2_score": r2_score(out_labels_regr, out_preds), + } + + +# Collate Function +def collator_fn(x): + if len(x) == 1: + return x[0] + print("x:", x) + return x + + +def get_gpu_utilization(): + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + info = pynvml.nvmlDeviceGetMemoryInfo(handle) + return info.used // 1024**2 # Convert bytes to MB + + +def main(args): + set_seed(args.seed) + + # Set the device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Set TORCH_HOME + efs_model_path = "/home/sagemaker-user/user-default-efs/torch_hub" + os.environ["TORCH_HOME"] = efs_model_path + if not os.path.exists(efs_model_path): + 
os.makedirs(efs_model_path) + + # Find the CSV file in the train and test directories + train_files = [ + os.path.join(args.train_dataset_path, f) + for f in os.listdir(args.train_dataset_path) + if f.endswith(".csv") + ] + test_files = [ + os.path.join(args.eval_dataset_path, f) + for f in os.listdir(args.eval_dataset_path) + if f.endswith(".csv") + ] + + # Assuming only one CSV file is present in each directory + train_df = pd.read_csv(train_files[0]) + test_df = pd.read_csv(test_files[0]) + + # Group by 'pdb_id_chain' and aggregate + train_set = ( + train_df.groupby("pdb_id_chain") + .agg({"resi_pos": list, "resi_aa": list, "contact_number": list}) + .reset_index() + ) + + test_set = ( + test_df.groupby("pdb_id_chain") + .agg({"resi_pos": list, "resi_aa": list, "contact_number_binary": list}) + .reset_index() + ) + + # Initialize datasets + train_ds = PDB_Dataset( + train_set[["resi_aa", "contact_number"]], label_type="regression" + ) + test_ds = PDB_Dataset( + test_set[["resi_aa", "contact_number_binary"]], label_type="regression" + ) + + # Initialize model + model = ESM1vForTokenClassification().to(device) + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize Trainer + training_args = TrainingArguments( + output_dir=tmp_dir, + overwrite_output_dir=True, + num_train_epochs=args.epochs, + per_device_train_batch_size=args.per_device_train_batch_size, + per_device_eval_batch_size=args.per_device_eval_batch_size, + warmup_steps=args.warmup_steps, + learning_rate=args.lr, + weight_decay=args.weight_decay, + gradient_accumulation_steps=args.gradient_accumulation_steps, + logging_dir=f"{tmp_dir}/logs", + logging_strategy="steps", + logging_steps=200, + save_strategy="no", + evaluation_strategy="epoch", + fp16=False, + load_best_model_at_end=False, + metric_for_best_model="eval_accuracy", + greater_is_better=True, + push_to_hub=False, + ) + + trainer = MaskedRegressTrainer( + model=model, + args=training_args, + train_dataset=train_ds, + eval_dataset=test_ds, + data_collator=collator_fn, + compute_metrics=compute_metrics_regr, + ) + + # Start training + trainer.train() + + gpu_memory_use = get_gpu_utilization() + print(f"Max GPU memory use during training: {gpu_memory_use} MB") + + # Save the model's state dictionary + torch.save( + trainer.model.state_dict(), + os.path.join(args.model_output_path, "model.pth"), + ) + print("Model state_dict saved.") + + # copy inference script + os.makedirs("/opt/ml/model/code", exist_ok=True) + shutil.copyfile( + os.path.join(os.path.dirname(__file__), "inference.py"), + "/opt/ml/model/code/inference.py", + ) + shutil.copyfile( + os.path.join(os.path.dirname(__file__), "requirements.txt"), + "/opt/ml/model/code/requirements.txt", + ) + + +def parse_args(): + """Parse the arguments.""" + parser = argparse.ArgumentParser() + + # SageMaker specific arguments + parser.add_argument( + "--train_dataset_path", + type=str, + default=os.environ["SM_CHANNEL_TRAIN"], + help="Path to train dataset.", + ) + parser.add_argument( + "--eval_dataset_path", + type=str, + default=os.environ["SM_CHANNEL_TEST"], + help="Path to evaluation dataset.", + ) + parser.add_argument( + "--model_output_path", + type=str, + default=os.environ["SM_MODEL_DIR"], + help="Path to model output folder.", + ) + + # Hyperparameters + parser.add_argument( + "--epochs", type=int, default=1, help="Number of epochs to train for." 
+ ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=1, + help="Batch size to use for training.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=1, + help="Batch size to use for evaluation.", + ) + parser.add_argument( + "--lr", type=float, default=1e-5, help="Learning rate to use for training." + ) + parser.add_argument("--warmup_steps", type=int, default=0) + parser.add_argument("--weight_decay", type=float, default=0.0) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument( + "--seed", type=int, default=42, help="Seed to use for training." + ) + + args = parser.parse_known_args() + return args + + +if __name__ == "__main__": + args, _ = parse_args() + main(args) From 40b72022960b161dd49b8150d7a852112a8b6eb5 Mon Sep 17 00:00:00 2001 From: ElektrikSpark Date: Sun, 20 Oct 2024 07:04:08 +0000 Subject: [PATCH 5/5] fix: updated sema 1d notebook, finalizing project --- .env.example | 4 +++- .../src/app/services/background_tasks.py | 13 ------------ .../sema-1d-sagemaker/scripts/inference.py | 21 +++++++++++++++++++ .../scripts/requirements.txt | 13 ++++++------ notebooks/sema-1d-sagemaker/scripts/train.py | 6 +++++- 5 files changed, 36 insertions(+), 21 deletions(-) diff --git a/.env.example b/.env.example index a6847bd..c1f3c32 100644 --- a/.env.example +++ b/.env.example @@ -5,10 +5,12 @@ # If you are cloning this repo, create a copy of this file named `.env` and populate it with your secrets. # The database URL is used to connect to your Supabase database. -POSTGRES_URL="postgres://postgres.[USERNAME]:[PASSWORD]@aws-0-eu-central-1.pooler.supabase.com:6543/postgres?workaround=supabase-pooler.vercel" +POSTGRES_URL="postgres://postgres.[USERNAME]:[PASSWORD]@aws-0-us-east-1.pooler.supabase.com:6543/postgres?workaround=supabase-pooler.vercel" # FastAPI NEXT_PUBLIC_FASTAPI_URL="http://127.0.0.1:8000" +NEXT_PUBLIC_FASTAPI_STAGE_URL="" +NEXT_PUBLIC_USE_LAMBDA_API="false" # Set to "true" when you want to test Lambda # Supabase NEXT_PUBLIC_SUPABASE_URL="" diff --git a/apps/fastapi/src/app/services/background_tasks.py b/apps/fastapi/src/app/services/background_tasks.py index bfaaae0..7e59888 100644 --- a/apps/fastapi/src/app/services/background_tasks.py +++ b/apps/fastapi/src/app/services/background_tasks.py @@ -42,19 +42,6 @@ async def process_and_update_prediction( # Update job status to 'running' await crud_job.update_status(db=db, id=job_id, status="running") - # Sample input - glp_1_receptor = "MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS" - sample = {"inputs": glp_1_receptor} - - # Run prediction pipeline - # predictions = run( - # input_objects=[sample], - # endpoint_name=settings.SAGEMAKER_ENDPOINT_NAME, - # model_type="single", - # ) - - # logger.info(predictions) - # Process the prediction based on its type if prediction_type == "conformational-b": if is_structure_based: diff --git a/notebooks/sema-1d-sagemaker/scripts/inference.py b/notebooks/sema-1d-sagemaker/scripts/inference.py index ebc1072..643f14c 100644 --- a/notebooks/sema-1d-sagemaker/scripts/inference.py +++ 
b/notebooks/sema-1d-sagemaker/scripts/inference.py @@ -1,3 +1,4 @@ +import logging import os import esm @@ -5,6 +6,8 @@ from torch import nn from transformers.modeling_outputs import SequenceClassifierOutput +logger = logging.getLogger(__name__) + def model_fn(model_dir): """ @@ -16,7 +19,10 @@ def model_fn(model_dir): Returns: Tuple[torch.nn.Module, esm.pretrained.Alphabet]: The loaded model and ESM batch converter. """ + logger.debug(f"Loading model from {model_dir}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.debug(f"Using device: {device}") # Define the custom model class class ESM1vForTokenClassification(nn.Module): @@ -28,28 +34,38 @@ def __init__(self, num_labels=2, pretrained_no=1): # Load the pretrained ESM-1v model and alphabet self.esm1v, self.esm1v_alphabet = esm.pretrained.esm1v_t33_650M_UR90S_1() self.classifier = nn.Linear(1280, self.num_labels) + logger.debug("ESM model initialized") def forward(self, token_ids, labels=None): outputs = self.esm1v.forward(token_ids, repr_layers=[33])[ "representations" ][33] + logger.debug( + f"Model forward pass completed, token_ids shape: {token_ids.shape}" + ) outputs = outputs[:, 1:-1, :] # Remove start and end tokens logits = self.classifier(outputs) return SequenceClassifierOutput(logits=logits) # Initialize and load the model model = ESM1vForTokenClassification().to(device) + logger.debug("Model instantiated and moved to device") # Load the state_dict from 'model.pth' model.load_state_dict( torch.load(os.path.join(model_dir, "model.pth"), map_location=device) ) + logger.debug("Model weights loaded from model.pth") + model.eval() + logger.debug("Model set to evaluation mode") # Initialize the ESM batch converter _, esm1v_alphabet = esm.pretrained.esm1v_t33_650M_UR90S_1() batch_converter = esm1v_alphabet.get_batch_converter() + logger.debug("Batch converter initialized") + return model, batch_converter @@ -90,11 +106,15 @@ def predict_fn(input_data, model_and_batch_converter): List[float]: Per-residue epitope probabilities. 
""" model, batch_converter = model_and_batch_converter + logger.debug(f"Received input_data: {input_data}") + device = next(model.parameters()).device + logger.debug(f"Model is using device: {device}") # Prepare the batch batch = batch_converter([("", input_data)]) token_ids = batch[2].to(device) + logger.debug(f"Batch prepared, token_ids shape: {token_ids.shape}") with torch.no_grad(): outputs = model(token_ids) @@ -104,6 +124,7 @@ def predict_fn(input_data, model_and_batch_converter): # Extract per-residue epitope probabilities (class 1 in a binary classification) epitope_probs = probs[0, :, 1].tolist() + logger.debug(f"Predicted probabilities: {epitope_probs}") return epitope_probs diff --git a/notebooks/sema-1d-sagemaker/scripts/requirements.txt b/notebooks/sema-1d-sagemaker/scripts/requirements.txt index 16253e8..8a1ab7d 100644 --- a/notebooks/sema-1d-sagemaker/scripts/requirements.txt +++ b/notebooks/sema-1d-sagemaker/scripts/requirements.txt @@ -1,7 +1,8 @@ -accelerate +accelerate==0.24.1 fair-esm -evaluate -nvidia-ml-py3 -scikit-learn -transformers -torchinfo +evaluate==0.4.1 +huggingface-hub +nvidia-ml-py3==7.352.0 +scikit-learn==1.3.2 +transformers==4.34.1 +torchinfo==1.8.0 diff --git a/notebooks/sema-1d-sagemaker/scripts/train.py b/notebooks/sema-1d-sagemaker/scripts/train.py index 741a1c1..a7bb95f 100644 --- a/notebooks/sema-1d-sagemaker/scripts/train.py +++ b/notebooks/sema-1d-sagemaker/scripts/train.py @@ -66,8 +66,12 @@ class ESM1vForTokenClassification(nn.Module): def __init__(self, num_labels=2, pretrained_no=1): super().__init__() self.num_labels = num_labels + self.model_name = ( + esm.pretrained.esm2_t33_650M_UR50D() + ) # load_model_and_alphabet_hub(self.model_name) + # Load the pretrained ESM model - self.esm1v, self.esm1v_alphabet = esm.pretrained.esm1v_t33_650M_UR90S_1() + self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D() self.classifier = nn.Linear(1280, self.num_labels) def forward(self, token_ids, labels=None):