diff --git a/notebooks/TCR_epitope_prediction_NewEpitope_NewTCR_LOOCV.ipynb b/notebooks/TCR_epitope_prediction_NewEpitope_NewTCR_LOOCV.ipynb new file mode 100644 index 0000000..53a9e8a --- /dev/null +++ b/notebooks/TCR_epitope_prediction_NewEpitope_NewTCR_LOOCV.ipynb @@ -0,0 +1,3173 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "-NSHxGYaV7vc" + }, + "outputs": [], + "source": [ + "%%capture\n", + "import os\n", + "from google.colab import drive\n", + "drive.mount('/content/drive')\n", + "path = \"/content/drive/MyDrive/COLAB/TCR_projects\"\n", + "os.chdir(path)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "be5dH_JwZLlu" + }, + "outputs": [], + "source": [ + "import random\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "from sklearn.model_selection import GridSearchCV\n", + "from sklearn.linear_model import LogisticRegression\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.neural_network import MLPClassifier\n", + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "############################## Negative sampling: random shuffling ##############################\n", + "\n", + "#The random shuffling strategy, also known as \"Random TCR\" or \"Random Epitope\" depending on which element is shuffled, involves creating artificial negative samples by randomly pairing TCRs with epitopes that are not their known binding partners.\n", + "#Separation of Data: First, the positive TCR-epitope pairs are split into training and test sets.\n", + "#Random Pairing: Within each set (train and test), TCRs or epitopes are randomly shuffled to create new pairs that are assumed to be non-binding.\n", + "#Negative Sample Generation: These new random pairs are labeled as negative samples.\n", + "\n", + "# write a script to pair TCRs and epitopes, excluding their true binding partners for mat_train_tab\n", + "# mat_train_tab has columns 
epitope_aa, tcr_full and value. A value of 1 means binding, a value of 0 means non-binding.\n", + "\n", + "def generate_negative_samples(df,tcr_features,num_epitopes=100, num_tcrs=10):\n", + " # generate negatives for both epitope and TCRs\n", + " negative_samples = []\n", + "\n", + " # Ensure the number of samples does not exceed the number of unique elements\n", + " unique_epitopes = list(df['epitope'].unique())\n", + " unique_tcrs = list(df[tcr_features].unique())\n", + "\n", + " sampled_epitopes = random.sample(unique_epitopes, min(num_epitopes, len(unique_epitopes)))\n", + " sampled_tcrs = random.sample(unique_tcrs, min(num_tcrs, len(unique_tcrs)))\n", + "\n", + " for epitope in sampled_epitopes:\n", + " for tcr in sampled_tcrs:\n", + " if df[(df['epitope'] == epitope) & (df[tcr_features] == tcr)].empty:\n", + " negative_samples.append({'epitope': epitope, tcr_features: tcr, 'value': 0})\n", + "\n", + " return pd.DataFrame(negative_samples)\n", + "\n", + "def generate_negative_samples_epitope(df,tcr_features, epitope, num_tcrs=10):\n", + " # generate negatives for TCRs only.\n", + " negative_samples = []\n", + "\n", + " # Ensure the number of samples does not exceed the number of unique elements\n", + " unique_epitopes = epitope\n", + " unique_tcrs = list(df[tcr_features].unique())\n", + "\n", + " #sampled_epitopes = random.sample(unique_epitopes, min(num_epitopes, len(unique_epitopes)))\n", + " sampled_tcrs = random.sample(unique_tcrs, min(num_tcrs, len(unique_tcrs)))\n", + "\n", + " #for epitope in sampled_epitopes:\n", + " for tcr in sampled_tcrs:\n", + " if df[(df['epitope'] == epitope) & (df[tcr_features] == tcr)].empty:\n", + " negative_samples.append({'epitope': epitope, tcr_features: tcr, 'value': 0})\n", + "\n", + " return pd.DataFrame(negative_samples)\n", + "\n", + "def preprocess_features(feat, res, train_indices, test_indices):\n", + " x_train = feat.iloc[train_indices, :]\n", + " y_train = res[train_indices]\n", + " x_test = feat.iloc[test_indices, 
:]\n", + " y_test = res[test_indices]\n", + " # scale the data\n", + " scaler = StandardScaler().fit(x_train)\n", + " x_train = pd.DataFrame(scaler.transform(x_train), index=x_train.index, columns=x_train.columns)\n", + " x_test = pd.DataFrame(scaler.transform(x_test), index=x_test.index, columns=x_test.columns)\n", + " return x_train, y_train, x_test, y_test\n", + "\n", + "def train (algo, x_train, y_train, x_test):\n", + " model = None\n", + " if algo == \"sklearn_mlp\":\n", + " model = MLPClassifier(\n", + " hidden_layer_sizes=(100, 50, 10), activation='relu',\n", + " learning_rate_init=0.001, alpha=0.01, max_iter=10,\n", + " early_stopping=False, validation_fraction=0.1\n", + " )\n", + " elif algo == \"sklearn_randomforest\":\n", + " model = RandomForestClassifier(\n", + " n_estimators=1000, max_depth=10, oob_score=True\n", + " )\n", + " elif algo == \"sklearn_logit\":\n", + " model = LogisticRegression(\n", + " C=0.5, solver=\"saga\", penalty=\"elasticnet\",\n", + " l1_ratio=0.5, class_weight=\"balanced\"\n", + " )\n", + "\n", + " if model:\n", + " model.fit(x_train, y_train)\n", + " x_train_proba = model.predict_proba(x_train)[:, 1]\n", + " x_test_proba = model.predict_proba(x_test)[:, 1]\n", + " return x_train_proba, x_test_proba, model\n", + " else:\n", + " raise ValueError(\"Specified algorithm is not supported\")\n", + "\n", + "\n", + "def get_feature_importance(model, algo):\n", + " # get model weights for each algorithm\n", + " temp = []\n", + " if algo == \"sklearn_logit\":\n", + " temp = model.coef_\n", + " if algo == \"sklearn_randomforest\":\n", + " temp = model.feature_importances_\n", + " if len(temp) == 1:\n", + " temp = temp.transpose()\n", + " return temp\n", + "\n", + "def aggregate_feature_importance(imp, f_name):\n", + " # aggregate by feature and iteration\n", + " imp_agg = imp.groupby(['feature', 'iteration'])['importance'].mean().reset_index()\n", + " # aggregate by feature and compute mean and sd across all iterations for each 
feature\n", + " imp_agg = imp_agg.groupby('feature').agg({'importance': ['mean', 'std']})\n", + " imp_agg.columns = ['importance', 'std']\n", + " imp_agg.reset_index(inplace=True)\n", + " imp_agg.columns = ['feature_name', 'importance', 'std']\n", + " return imp_agg\n", + "\n", + "\n", + "########################################## run ML ###########################################\n", + "\n", + "def run_ML(epitope,epitope_embeddings,tcr_embeddings):\n", + " print(epitope)\n", + "\n", + " run_ite = 0 # Initialize iteration number of each repetition, outer loop\n", + " probability_all = pd.DataFrame()\n", + " importance_all = pd.DataFrame()\n", + " for rep in range(repetition):\n", + " run_ite += 1\n", + "\n", + " # split the dataset\n", + " mat_train = mat.drop(epitope, axis=1)\n", + " mat_train_tab = mat_train.stack().reset_index().rename(columns={0:'value'})\n", + "\n", + " mat_test = pd.DataFrame(mat[epitope])\n", + " mat_test_tab = mat_test.stack().reset_index().rename(columns={0:'value'})\n", + " mat_test_tab.columns = mat_train_tab.columns\n", + "\n", + " ############################## Negative sampling: random shuffling ##############################\n", + "\n", + " # pos / neg ratio\n", + " mat_train_pos = len(mat_train_tab)\n", + " epitope_number = len(np.unique(mat_train_tab[\"epitope\"]))\n", + " negative_df = generate_negative_samples(mat_train_tab, tcr_features=tcr_features,\n", + " num_epitopes=epitope_number, num_tcrs=ratio*round(mat_train_pos/epitope_number))\n", + " mat_train_tab = pd.concat([mat_train_tab[negative_df.columns], negative_df], ignore_index=True)\n", + "\n", + "\n", + " max_pos_cases = 1000 # max positive cases for test set\n", + " mat_test_tab = mat_test_tab.sample(n=min(max_pos_cases, len(mat_test_tab)))\n", + " mat_test_pos = np.min([max_pos_cases, len(mat_test_tab)])\n", + "\n", + " # LOOCV, so TCRs have to be taken from training set\n", + " negative_df = generate_negative_samples_epitope(mat_train_tab, tcr_features, 
epitope=epitope, num_tcrs = ratio*mat_test_pos )\n", + " mat_test_tab = pd.concat([mat_test_tab[negative_df.columns], negative_df], ignore_index=True)\n", + " # remove the TCRs of the test set from the training set\n", + " mat_train_tab = mat_train_tab.loc[~mat_train_tab[tcr_features].isin(mat_test_tab[tcr_features]), :]\n", + "\n", + " print(mat_train_tab.value.value_counts() )\n", + " print(mat_test_tab.value.value_counts() )\n", + "\n", + "\n", + " def get_embeddings(row):\n", + " epitope = epitope_embeddings.loc[row['epitope']].values\n", + " tcr = tcr_embeddings.loc[row[tcr_features]].values\n", + " return np.concatenate((epitope, tcr))\n", + "\n", + " ################# training set features\n", + " features_train = mat_train_tab.apply(get_embeddings, axis=1)\n", + " features_train = pd.DataFrame(features_train.tolist(), index=features_train.index)\n", + " features_train.index = mat_train_tab[\"epitope\"] + \"_\" + mat_train_tab[tcr_features]\n", + " features_train.columns = epitope_embeddings.columns.tolist() + tcr_embeddings.columns.tolist()\n", + " ## add other information\n", + " df_encoded_TCR_subset = df_encoded_TCR.loc[mat_train_tab[tcr_features], : ]\n", + " df_encoded_epitope_subset = df_encoded_epitope.loc[mat_train_tab[\"epitope\"], : ]\n", + "\n", + " ## combine\n", + " # \"ESM3 + VJ genes\" \"all features\" \"ESMonly\" \"withoutESM\"\n", + " if features_name == \"ESM3 + VJ genes\":\n", + " features_train_all = pd.concat([features_train.reset_index(drop=True), df_encoded_TCR_subset.reset_index(drop=True)], axis=1)\n", + " if features_name == \"all features\":\n", + " features_train_all = pd.concat([features_train.reset_index(drop=True), df_encoded_TCR_subset.reset_index(drop=True), df_encoded_epitope_subset.reset_index(drop=True)], axis=1)\n", + " if features_name == \"ESMonly\":\n", + " features_train_all = features_train\n", + " if features_name == \"withoutESM\":\n", + " features_train_all = 
pd.concat([df_encoded_TCR_subset.reset_index(drop=True), df_encoded_epitope_subset.reset_index(drop=True)], axis=1)\n", + "\n", + " features_train_all.index = features_train.index\n", + "\n", + " ################# test set features\n", + " features_test = mat_test_tab.apply(get_embeddings, axis=1)\n", + " features_test = pd.DataFrame(features_test.tolist(), index=features_test.index)\n", + " features_test.index = mat_test_tab[\"epitope\"] + \"_\" + mat_test_tab[tcr_features]\n", + " features_test.columns = epitope_embeddings.columns.tolist() + tcr_embeddings.columns.tolist()\n", + " ## add other information\n", + " df_encoded_TCR_subset = df_encoded_TCR.loc[mat_test_tab[tcr_features], : ]\n", + " df_encoded_epitope_subset = df_encoded_epitope.loc[mat_test_tab[\"epitope\"], : ]\n", + "\n", + " ## combine\n", + " # \"ESM3 + VJ genes\" \"all features\" \"ESMonly\" \"withoutESM\"\n", + " if features_name == \"ESM3 + VJ genes\":\n", + " features_test_all = pd.concat([features_test.reset_index(drop=True), df_encoded_TCR_subset.reset_index(drop=True)], axis=1)\n", + " if features_name == \"all features\":\n", + " features_test_all = pd.concat([features_test.reset_index(drop=True), df_encoded_TCR_subset.reset_index(drop=True), df_encoded_epitope_subset.reset_index(drop=True)], axis=1)\n", + " if features_name == \"ESMonly\":\n", + " features_test_all = features_test\n", + " if features_name == \"withoutESM\":\n", + " features_test_all = pd.concat([df_encoded_TCR_subset.reset_index(drop=True), df_encoded_epitope_subset.reset_index(drop=True)], axis=1)\n", + "\n", + " features_test_all.index = features_test.index\n", + "\n", + " RES_train = mat_train_tab.value.tolist()\n", + " RES_test = mat_test_tab.value.tolist()\n", + "\n", + " ############################################ run ML ############################################\n", + " X_train = features_train_all\n", + " X_test = features_test_all\n", + " y_train = RES_train\n", + " y_test = RES_test\n", + "\n", + " 
X_train.columns = X_train.columns.astype(str)\n", + " X_test.columns = X_test.columns.astype(str)\n", + "\n", + " train_prob, test_prob, model = train (algorithm, X_train, y_train, X_test)\n", + "\n", + " p_test = pd.DataFrame(\n", + " { 'split': \"test\",\n", + " 'epitope': epitope,\n", + " 'iteration': [run_ite] * len(y_test),\n", + " 'sample': X_test.index,\n", + " 'predicted_prob': test_prob,\n", + " 'RealClass': y_test\n", + " }\n", + " )\n", + "\n", + " p_train = pd.DataFrame(\n", + " { 'split': \"train\",\n", + " 'epitope': epitope,\n", + " 'iteration': [run_ite] * len(y_train),\n", + " 'sample': X_train.index,\n", + " 'predicted_prob': train_prob,\n", + " 'RealClass': y_train\n", + " }\n", + " )\n", + "\n", + " p = pd.concat([p_test, p_train], ignore_index=True)\n", + " probability_all = pd.concat([probability_all, p], ignore_index=True)\n", + "\n", + "\n", + " # Get feature importance from the model\n", + " importance_values = get_feature_importance(model=model, algo=algorithm)\n", + " importance = pd.DataFrame({\n", + " 'iteration': [run_ite] * len(importance_values),\n", + " 'feature': X_train.columns,\n", + " 'importance': importance_values.flatten()\n", + " })\n", + "\n", + " # append to the importance_all list\n", + " importance_all = pd.concat([importance_all, importance], ignore_index=True)\n", + "\n", + " return probability_all, importance_all\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "from sklearn.metrics import roc_curve, auc\n", + "from numpy import interp\n", + "\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "\n", + "def plot_ROC_curve(probability):\n", + " aucs_all = []\n", + " mean_fpr = np.linspace(0, 1, 100)\n", + " for i in range(1, len(probability[\"iteration\"].unique()) + 1):\n", + " pred_run = probability[probability[\"iteration\"] == i]\n", + " epitopes = np.unique( [i.split('_', 1)[0] for i in pred_run[\"sample\"]] )\n", + "\n", + " # plot an average ROC curve across all runs. 
Values are interpolated.\n", + " tprs = []\n", + " aucs = []\n", + " for epitope in epitopes:\n", + "\n", + " # select epitope\n", + " pred_run = probability[probability[\"iteration\"] == i]\n", + " pred_run = pred_run[pred_run['sample'].str.contains(epitope)]\n", + "\n", + " fpr, tpr, thresh = roc_curve(pred_run[\"RealClass\"], pred_run[\"predicted_prob\"])\n", + " interpolated_tpr = interp(mean_fpr, fpr, tpr)\n", + " interpolated_tpr[0] = 0.0\n", + " roc_auc = auc(fpr, tpr)\n", + " tprs.append(interpolated_tpr)\n", + " aucs = pd.DataFrame([epitope, roc_auc]).transpose()\n", + "\n", + " if len(aucs_all)== 0:\n", + " aucs_all = aucs\n", + " else:\n", + " aucs_all = pd.concat([aucs_all, aucs])\n", + "\n", + " aucs_all.columns = [\"epitope\",\"AUC\"]\n", + " aucs_all = aucs_all.groupby('epitope').agg('mean')\n", + "\n", + " auc_values = aucs_all.AUC\n", + " print(auc_values)\n", + "\n", + "\n", + " # Assuming 'auc_values' is a pandas Series\n", + " plt.figure(figsize=(3, 5))\n", + " sns.boxplot(y=auc_values, color='lightblue', width=0.4) # Adjust width here\n", + " sns.swarmplot(y=auc_values, color='darkred', size=6) # Adjust size here\n", + " plt.xlabel('AUC')\n", + " plt.title('AUC Scores')\n", + " plt.ylim(0, 1) # Set y-axis limits to 0-1\n", + " # save to pdf\n", + " os.makedirs(\"fig\", exist_ok=True)\n", + " plt.savefig(\"fig/auc_values_epitope_boxplot_\"+features_name+\"_\"+algorithm+\"_N_TCRs\"+str(N_TCRs)+\"_\"+split+\"_species_\"+species+\"_MHC_\"+MHC_class+\".pdf\", format='pdf', bbox_inches='tight')\n", + " plt.show()\n", + "\n", + " print(f\"Mean AUC: {np.mean(auc_values)}\")\n", + " print(f\"Median AUC: {np.median(auc_values)}\")\n", + "\n", + " return auc_values\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "eB-K1R38U_Y3", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 548 + }, + "outputId": "bbc6d27b-f8f6-4614-b0b3-a3741ecaee12" + }, + "outputs": [ + { + "output_type": "execute_result", + 
"data": { + "text/plain": [ + " epitope cdr3_TRA \\\n", + "DIYKGVYQFKSV_CAGGADRLTFCASSPAGNTLYF DIYKGVYQFKSV CAGGADRLTF \n", + "DIYKGVYQFKSV_CAASGGSNYNVLYFCAWSLWGGPSAETLYF DIYKGVYQFKSV CAASGGSNYNVLYF \n", + "DIYKGVYQFKSV_CAASYNYAQGLTFCASRDWGGRQDTQYF DIYKGVYQFKSV CAASYNYAQGLTF \n", + "DIYKGVYQFKSV_CAAQTGNYKYVFCASGDAGTGQDTQYF DIYKGVYQFKSV CAAQTGNYKYVF \n", + "DIYKGVYQFKSV_CAASLTGGYKVVFCAWRTDNQDTQYF DIYKGVYQFKSV CAASLTGGYKVVF \n", + "... ... ... \n", + "GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYF GMGPLLATV CAVLNNARLMF \n", + "GMGPLLATV_CATDNDMRFCASSFGPDEQYF GMGPLLATV CATDNDMRF \n", + "GMGPLLATV_CAYRSFNNNDMRFCASRSRGGHSPLHF GMGPLLATV CAYRSFNNNDMRF \n", + "GMGPLLATV_CAMTSFQKLVFCASSLRGEKNNYGYTF GMGPLLATV CAMTSFQKLVF \n", + "GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYF GMGPLLATV CAVLNNARLMF \n", + "\n", + " cdr3_TRB TRAV \\\n", + "DIYKGVYQFKSV_CAGGADRLTFCASSPAGNTLYF CASSPAGNTLYF TRAV14-3 \n", + "DIYKGVYQFKSV_CAASGGSNYNVLYFCAWSLWGGPSAETLYF CAWSLWGGPSAETLYF TRAV14N-3 \n", + "DIYKGVYQFKSV_CAASYNYAQGLTFCASRDWGGRQDTQYF CASRDWGGRQDTQYF TRAV14N-3 \n", + "DIYKGVYQFKSV_CAAQTGNYKYVFCASGDAGTGQDTQYF CASGDAGTGQDTQYF TRAV14D-3-DV8 \n", + "DIYKGVYQFKSV_CAASLTGGYKVVFCAWRTDNQDTQYF CAWRTDNQDTQYF TRAV14N-3 \n", + "... ... ... \n", + "GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYF CASSVDRVADTQYF TRAV12-2 \n", + "GMGPLLATV_CATDNDMRFCASSFGPDEQYF CASSFGPDEQYF NaN \n", + "GMGPLLATV_CAYRSFNNNDMRFCASRSRGGHSPLHF CASRSRGGHSPLHF NaN \n", + "GMGPLLATV_CAMTSFQKLVFCASSLRGEKNNYGYTF CASSLRGEKNNYGYTF TRAV39 \n", + "GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYF CASSVDRVADTQYF NaN \n", + "\n", + " TRAJ TRBV \\\n", + "DIYKGVYQFKSV_CAGGADRLTFCASSPAGNTLYF TRAJ45 TRBV14 \n", + "DIYKGVYQFKSV_CAASGGSNYNVLYFCAWSLWGGPSAETLYF TRAJ21 TRBV31 \n", + "DIYKGVYQFKSV_CAASYNYAQGLTFCASRDWGGRQDTQYF TRAJ26 TRBV13-3 \n", + "DIYKGVYQFKSV_CAAQTGNYKYVFCASGDAGTGQDTQYF TRAJ40 TRBV12-2+TRBV13-2 \n", + "DIYKGVYQFKSV_CAASLTGGYKVVFCAWRTDNQDTQYF TRAJ12 TRBV31 \n", + "... ... ... 
\n", + "GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYF NaN TRBV27 \n", + "GMGPLLATV_CATDNDMRFCASSFGPDEQYF NaN TRBV13 \n", + "GMGPLLATV_CAYRSFNNNDMRFCASRSRGGHSPLHF NaN TRBV13 \n", + "GMGPLLATV_CAMTSFQKLVFCASSLRGEKNNYGYTF NaN TRBV13 \n", + "GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYF NaN TRBV27 \n", + "\n", + " TRBJ MHC MHC_class \\\n", + "DIYKGVYQFKSV_CAGGADRLTFCASSPAGNTLYF TRBJ1-3 H2-IAb MHCII \n", + "DIYKGVYQFKSV_CAASGGSNYNVLYFCAWSLWGGPSAETLYF TRBJ2-3 H2-IAb MHCII \n", + "DIYKGVYQFKSV_CAASYNYAQGLTFCASRDWGGRQDTQYF TRBJ2-5 H2-IAb MHCII \n", + "DIYKGVYQFKSV_CAAQTGNYKYVFCASGDAGTGQDTQYF TRBJ2-5 H2-IAb MHCII \n", + "DIYKGVYQFKSV_CAASLTGGYKVVFCAWRTDNQDTQYF TRBJ2-5 H2-IAb MHCII \n", + "... ... ... ... \n", + "GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYF NaN HLA-A*02:01 MHCI \n", + "GMGPLLATV_CATDNDMRFCASSFGPDEQYF NaN HLA-A*02:01 MHCI \n", + "GMGPLLATV_CAYRSFNNNDMRFCASRSRGGHSPLHF NaN HLA-A*02:01 MHCI \n", + "GMGPLLATV_CAMTSFQKLVFCASSLRGEKNNYGYTF NaN HLA-A*02:01 MHCI \n", + "GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYF NaN HLA-A*02:01 MHCI \n", + "\n", + " species \\\n", + "DIYKGVYQFKSV_CAGGADRLTFCASSPAGNTLYF MusMusculus \n", + "DIYKGVYQFKSV_CAASGGSNYNVLYFCAWSLWGGPSAETLYF MusMusculus \n", + "DIYKGVYQFKSV_CAASYNYAQGLTFCASRDWGGRQDTQYF MusMusculus \n", + "DIYKGVYQFKSV_CAAQTGNYKYVFCASGDAGTGQDTQYF MusMusculus \n", + "DIYKGVYQFKSV_CAASLTGGYKVVFCAWRTDNQDTQYF MusMusculus \n", + "... ... 
\n", + "GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYF HomoSapiens \n", + "GMGPLLATV_CATDNDMRFCASSFGPDEQYF HomoSapiens \n", + "GMGPLLATV_CAYRSFNNNDMRFCASRSRGGHSPLHF HomoSapiens \n", + "GMGPLLATV_CAMTSFQKLVFCASSLRGEKNNYGYTF HomoSapiens \n", + "GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYF HomoSapiens \n", + "\n", + " cdr3 \\\n", + "DIYKGVYQFKSV_CAGGADRLTFCASSPAGNTLYF CAGGADRLTFCASSPAGNTLYF \n", + "DIYKGVYQFKSV_CAASGGSNYNVLYFCAWSLWGGPSAETLYF CAASGGSNYNVLYFCAWSLWGGPSAETLYF \n", + "DIYKGVYQFKSV_CAASYNYAQGLTFCASRDWGGRQDTQYF CAASYNYAQGLTFCASRDWGGRQDTQYF \n", + "DIYKGVYQFKSV_CAAQTGNYKYVFCASGDAGTGQDTQYF CAAQTGNYKYVFCASGDAGTGQDTQYF \n", + "DIYKGVYQFKSV_CAASLTGGYKVVFCAWRTDNQDTQYF CAASLTGGYKVVFCAWRTDNQDTQYF \n", + "... ... \n", + "GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYF CAVLNNARLMFCASSVDRVADTQYF \n", + "GMGPLLATV_CATDNDMRFCASSFGPDEQYF CATDNDMRFCASSFGPDEQYF \n", + "GMGPLLATV_CAYRSFNNNDMRFCASRSRGGHSPLHF CAYRSFNNNDMRFCASRSRGGHSPLHF \n", + "GMGPLLATV_CAMTSFQKLVFCASSLRGEKNNYGYTF CAMTSFQKLVFCASSLRGEKNNYGYTF \n", + "GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYF CAVLNNARLMFCASSVDRVADTQYF \n", + "\n", + " value \n", + "DIYKGVYQFKSV_CAGGADRLTFCASSPAGNTLYF 1 \n", + "DIYKGVYQFKSV_CAASGGSNYNVLYFCAWSLWGGPSAETLYF 1 \n", + "DIYKGVYQFKSV_CAASYNYAQGLTFCASRDWGGRQDTQYF 1 \n", + "DIYKGVYQFKSV_CAAQTGNYKYVFCASGDAGTGQDTQYF 1 \n", + "DIYKGVYQFKSV_CAASLTGGYKVVFCAWRTDNQDTQYF 1 \n", + "... ... \n", + "GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYF 1 \n", + "GMGPLLATV_CATDNDMRFCASSFGPDEQYF 1 \n", + "GMGPLLATV_CAYRSFNNNDMRFCASRSRGGHSPLHF 1 \n", + "GMGPLLATV_CAMTSFQKLVFCASSLRGEKNNYGYTF 1 \n", + "GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYF 1 \n", + "\n", + "[17676 rows x 12 columns]" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epitopecdr3_TRAcdr3_TRBTRAVTRAJTRBVTRBJMHCMHC_classspeciescdr3value
DIYKGVYQFKSV_CAGGADRLTFCASSPAGNTLYFDIYKGVYQFKSVCAGGADRLTFCASSPAGNTLYFTRAV14-3TRAJ45TRBV14TRBJ1-3H2-IAbMHCIIMusMusculusCAGGADRLTFCASSPAGNTLYF1
DIYKGVYQFKSV_CAASGGSNYNVLYFCAWSLWGGPSAETLYFDIYKGVYQFKSVCAASGGSNYNVLYFCAWSLWGGPSAETLYFTRAV14N-3TRAJ21TRBV31TRBJ2-3H2-IAbMHCIIMusMusculusCAASGGSNYNVLYFCAWSLWGGPSAETLYF1
DIYKGVYQFKSV_CAASYNYAQGLTFCASRDWGGRQDTQYFDIYKGVYQFKSVCAASYNYAQGLTFCASRDWGGRQDTQYFTRAV14N-3TRAJ26TRBV13-3TRBJ2-5H2-IAbMHCIIMusMusculusCAASYNYAQGLTFCASRDWGGRQDTQYF1
DIYKGVYQFKSV_CAAQTGNYKYVFCASGDAGTGQDTQYFDIYKGVYQFKSVCAAQTGNYKYVFCASGDAGTGQDTQYFTRAV14D-3-DV8TRAJ40TRBV12-2+TRBV13-2TRBJ2-5H2-IAbMHCIIMusMusculusCAAQTGNYKYVFCASGDAGTGQDTQYF1
DIYKGVYQFKSV_CAASLTGGYKVVFCAWRTDNQDTQYFDIYKGVYQFKSVCAASLTGGYKVVFCAWRTDNQDTQYFTRAV14N-3TRAJ12TRBV31TRBJ2-5H2-IAbMHCIIMusMusculusCAASLTGGYKVVFCAWRTDNQDTQYF1
.......................................
GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYFGMGPLLATVCAVLNNARLMFCASSVDRVADTQYFTRAV12-2NaNTRBV27NaNHLA-A*02:01MHCIHomoSapiensCAVLNNARLMFCASSVDRVADTQYF1
GMGPLLATV_CATDNDMRFCASSFGPDEQYFGMGPLLATVCATDNDMRFCASSFGPDEQYFNaNNaNTRBV13NaNHLA-A*02:01MHCIHomoSapiensCATDNDMRFCASSFGPDEQYF1
GMGPLLATV_CAYRSFNNNDMRFCASRSRGGHSPLHFGMGPLLATVCAYRSFNNNDMRFCASRSRGGHSPLHFNaNNaNTRBV13NaNHLA-A*02:01MHCIHomoSapiensCAYRSFNNNDMRFCASRSRGGHSPLHF1
GMGPLLATV_CAMTSFQKLVFCASSLRGEKNNYGYTFGMGPLLATVCAMTSFQKLVFCASSLRGEKNNYGYTFTRAV39NaNTRBV13NaNHLA-A*02:01MHCIHomoSapiensCAMTSFQKLVFCASSLRGEKNNYGYTF1
GMGPLLATV_CAVLNNARLMFCASSVDRVADTQYFGMGPLLATVCAVLNNARLMFCASSVDRVADTQYFNaNNaNTRBV27NaNHLA-A*02:01MHCIHomoSapiensCAVLNNARLMFCASSVDRVADTQYF1
\n", + "

17676 rows × 12 columns

\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "variable_name": "combined_df", + "summary": "{\n \"name\": \"combined_df\",\n \"rows\": 17676,\n \"fields\": [\n {\n \"column\": \"epitope\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 144,\n \"samples\": [\n \"LLEFYLAMPFATP\",\n \"LTDEMIAQY\",\n \"FLYNLLTRV\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"cdr3_TRA\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 9594,\n \"samples\": [\n \"CAVGANDYKLSF\",\n \"CALSEADSWGKLQF\",\n \"CAVDNARLMF\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"cdr3_TRB\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 10764,\n \"samples\": [\n \"CASGETSQNTLYF\",\n \"CAISDSAAGNNEQFF\",\n \"CASSLAGGDGGSTDTQYF\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"TRAV\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 244,\n \"samples\": [\n \"TRAV7D-5\",\n \"TRAV10\",\n \"TRAV6-4*03\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"TRAJ\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 90,\n \"samples\": [\n \"TRAJ2\",\n \"TRAJ30\",\n \"TRAJ24*02\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"TRBV\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 142,\n \"samples\": [\n \"TRBV8-1\",\n \"TRBV26\",\n \"TRBV6-04\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"TRBJ\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 32,\n \"samples\": [\n \"TCRBJ1-1\",\n \"TCRBJ1-5\",\n \"TRBJ5-6\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"MHC\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 36,\n 
\"samples\": [\n \"HLA-DRA:01\",\n \"HLA-A*08:01\",\n \"HLA-DRB1*04:05\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"MHC_class\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"MHCI\",\n \"MHCII\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"species\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"HomoSapiens\",\n \"MusMusculus\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"cdr3\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 14106,\n \"samples\": [\n \"CAGSNTNAGKSTFCASSIRSSYEQYF\",\n \"CTSVLLANYGNEKITFCAWSPISDYTF\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"value\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 1,\n \"max\": 1,\n \"num_unique_values\": 1,\n \"samples\": [\n 1\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" + } + }, + "metadata": {}, + "execution_count": 3 + } + ], + "source": [ + "################################# TCR-Epitope Binding Affinity Prediction Task #################################\n", + "os.chdir(path)\n", + "combined_df = pd.read_csv(\"MixTCRpred/full_training_set_146pmhc.csv\")\n", + "\n", + "# combine cdr3\n", + "combined_df[\"cdr3\"] = combined_df[\"cdr3_TRA\"] + combined_df[\"cdr3_TRB\"]\n", + "combined_df[\"value\"] = 1\n", + "combined_df.index = combined_df[\"epitope\"] + \"_\" + combined_df[\"cdr3\"]\n", + "combined_df\n", + "\n", + "############################################ choose ESM model ############################################\n", + "\n", + "# \"esm3-small-2024-08\" \"esm2_t6_8M_UR50D\"\n", + "model_name = \"esm3-small-2024-08\"\n", + "epitope_embeddings = pd.read_csv('MixTCRpred/data/epitope_embeddings_'+model_name+'.csv',index_col=0)\n", + "cdr3_embeddings = 
pd.read_csv('MixTCRpred/data/cdr3_embeddings_'+model_name+'.csv',index_col=0)\n", + "\n", + "epitope_embeddings.columns = \"epitope_\" + epitope_embeddings.columns\n", + "cdr3_embeddings.columns = \"cdr3_\" + cdr3_embeddings.columns\n", + "\n", + "############################## subset of available embeddings ##############################\n", + "combined_df = combined_df.loc[combined_df[\"epitope\"].isin(epitope_embeddings.index) , :]\n", + "combined_df = combined_df.loc[combined_df[\"cdr3\"].isin(cdr3_embeddings.index) , :]\n", + "combined_df" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 510 + }, + "id": "8cmFLvsb4mnT", + "outputId": "e15e20b4-d66b-4d22-da6b-1838b52faa11" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " MHC_H2-Db MHC_H2-IAb MHC_H2-IEk MHC_H2-Kb MHC_H2-Kd \\\n", + "epitope \n", + "DIYKGVYQFKSV 0.0 1.0 0.0 0.0 0.0 \n", + "GILGFVFTL 0.0 0.0 0.0 0.0 0.0 \n", + "SSYRRPVGI 0.0 0.0 0.0 1.0 0.0 \n", + "SSLENFRAYV 1.0 0.0 0.0 0.0 0.0 \n", + "LLWNGPMAV 0.0 0.0 0.0 0.0 0.0 \n", + "... ... ... ... ... ... \n", + "PQPELPYPQPE 0.0 0.0 0.0 0.0 0.0 \n", + "PGVLLKEFTVSGNIL 0.0 0.0 0.0 0.0 0.0 \n", + "LLLEWLAMA 0.0 0.0 0.0 0.0 0.0 \n", + "KGYVYQGL 0.0 0.0 0.0 1.0 0.0 \n", + "GMGPLLATV 0.0 0.0 0.0 0.0 0.0 \n", + "\n", + " MHC_H2-Ld MHC_HLA-A*02:01 MHC_HLA-A*08:01 MHC_HLA-A*11:01 \\\n", + "epitope \n", + "DIYKGVYQFKSV 0.0 0.0 0.0 0.0 \n", + "GILGFVFTL 0.0 1.0 0.0 0.0 \n", + "SSYRRPVGI 0.0 0.0 0.0 0.0 \n", + "SSLENFRAYV 0.0 0.0 0.0 0.0 \n", + "LLWNGPMAV 0.0 1.0 0.0 0.0 \n", + "... ... ... ... ... \n", + "PQPELPYPQPE 0.0 0.0 0.0 0.0 \n", + "PGVLLKEFTVSGNIL 0.0 0.0 0.0 0.0 \n", + "LLLEWLAMA 0.0 1.0 0.0 0.0 \n", + "KGYVYQGL 0.0 0.0 0.0 0.0 \n", + "GMGPLLATV 0.0 1.0 0.0 0.0 \n", + "\n", + " MHC_HLA-A*24:02 ... MHC_HLA-DQA1:02/DQB1*06:02 \\\n", + "epitope ... \n", + "DIYKGVYQFKSV 0.0 ... 0.0 \n", + "GILGFVFTL 0.0 ... 
0.0 \n", + "SSYRRPVGI 0.0 ... 0.0 \n", + "SSLENFRAYV 0.0 ... 0.0 \n", + "LLWNGPMAV 0.0 ... 0.0 \n", + "... ... ... ... \n", + "PQPELPYPQPE 0.0 ... 0.0 \n", + "PGVLLKEFTVSGNIL 0.0 ... 0.0 \n", + "LLLEWLAMA 0.0 ... 0.0 \n", + "KGYVYQGL 0.0 ... 0.0 \n", + "GMGPLLATV 0.0 ... 0.0 \n", + "\n", + " MHC_HLA-DRA:01 MHC_HLA-DRA:01/DRB1:01 MHC_HLA-DRB1*04:01 \\\n", + "epitope \n", + "DIYKGVYQFKSV 0.0 0.0 0.0 \n", + "GILGFVFTL 0.0 0.0 0.0 \n", + "SSYRRPVGI 0.0 0.0 0.0 \n", + "SSLENFRAYV 0.0 0.0 0.0 \n", + "LLWNGPMAV 0.0 0.0 0.0 \n", + "... ... ... ... \n", + "PQPELPYPQPE 0.0 0.0 0.0 \n", + "PGVLLKEFTVSGNIL 1.0 0.0 0.0 \n", + "LLLEWLAMA 0.0 0.0 0.0 \n", + "KGYVYQGL 0.0 0.0 0.0 \n", + "GMGPLLATV 0.0 0.0 0.0 \n", + "\n", + " MHC_HLA-DRB1*04:05 MHC_HLA-DRB1*07:01 MHC_HLA-DRB1*11:01 \\\n", + "epitope \n", + "DIYKGVYQFKSV 0.0 0.0 0.0 \n", + "GILGFVFTL 0.0 0.0 0.0 \n", + "SSYRRPVGI 0.0 0.0 0.0 \n", + "SSLENFRAYV 0.0 0.0 0.0 \n", + "LLWNGPMAV 0.0 0.0 0.0 \n", + "... ... ... ... \n", + "PQPELPYPQPE 0.0 0.0 0.0 \n", + "PGVLLKEFTVSGNIL 0.0 0.0 0.0 \n", + "LLLEWLAMA 0.0 0.0 0.0 \n", + "KGYVYQGL 0.0 0.0 0.0 \n", + "GMGPLLATV 0.0 0.0 0.0 \n", + "\n", + " MHC_HLA-DRB1:01 MHC_class_MHCII species_MusMusculus \n", + "epitope \n", + "DIYKGVYQFKSV 0.0 1.0 1.0 \n", + "GILGFVFTL 0.0 0.0 0.0 \n", + "SSYRRPVGI 0.0 0.0 1.0 \n", + "SSLENFRAYV 0.0 0.0 1.0 \n", + "LLWNGPMAV 0.0 0.0 0.0 \n", + "... ... ... ... \n", + "PQPELPYPQPE 0.0 1.0 0.0 \n", + "PGVLLKEFTVSGNIL 0.0 1.0 0.0 \n", + "LLLEWLAMA 0.0 0.0 0.0 \n", + "KGYVYQGL 0.0 1.0 1.0 \n", + "GMGPLLATV 0.0 0.0 0.0 \n", + "\n", + "[144 rows x 37 columns]" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MHC_H2-DbMHC_H2-IAbMHC_H2-IEkMHC_H2-KbMHC_H2-KdMHC_H2-LdMHC_HLA-A*02:01MHC_HLA-A*08:01MHC_HLA-A*11:01MHC_HLA-A*24:02...MHC_HLA-DQA1:02/DQB1*06:02MHC_HLA-DRA:01MHC_HLA-DRA:01/DRB1:01MHC_HLA-DRB1*04:01MHC_HLA-DRB1*04:05MHC_HLA-DRB1*07:01MHC_HLA-DRB1*11:01MHC_HLA-DRB1:01MHC_class_MHCIIspecies_MusMusculus
epitope
DIYKGVYQFKSV0.01.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.01.01.0
GILGFVFTL0.00.00.00.00.00.01.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
SSYRRPVGI0.00.00.01.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.01.0
SSLENFRAYV1.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.01.0
LLWNGPMAV0.00.00.00.00.00.01.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
..................................................................
PQPELPYPQPE0.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.01.00.0
PGVLLKEFTVSGNIL0.00.00.00.00.00.00.00.00.00.0...0.01.00.00.00.00.00.00.01.00.0
LLLEWLAMA0.00.00.00.00.00.01.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
KGYVYQGL0.00.00.01.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.01.01.0
GMGPLLATV0.00.00.00.00.00.01.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
\n", + "

144 rows × 37 columns

\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "variable_name": "df_encoded_epitope" + } + }, + "metadata": {}, + "execution_count": 4 + } + ], + "source": [ + "#################### encode additional information for the epitopes ####################\n", + "from sklearn.preprocessing import OneHotEncoder\n", + "\n", + "# One hot encoding of categorical variables\n", + "columns_to_encode = ['MHC','MHC_class','species']\n", + "df = combined_df.loc[:,columns_to_encode]\n", + "\n", + "one_hot_encoder = OneHotEncoder(drop='first', sparse_output=False)\n", + "# Fit and transform the data\n", + "one_hot_encoded = one_hot_encoder.fit_transform(df)\n", + "feature_names = one_hot_encoder.get_feature_names_out(df.columns)\n", + "df_encoded = pd.DataFrame(one_hot_encoded, columns=feature_names)\n", + "\n", + "df_encoded.index = combined_df[\"epitope\"]\n", + "df_encoded_epitope = df_encoded[~df_encoded.index.duplicated(keep=\"first\")]\n", + "df_encoded_epitope" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 493 + }, + "id": "UrSmETAb3b3W", + "outputId": "8dff008f-0981-48d1-c1cb-6169a5da6852" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " TRAV_TCRAV12-1 TRAV_TCRAV17 TRAV_TCRAV19 \\\n", + "cdr3 \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 0.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 0.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 0.0 \n", + "... ... ... ... 
\n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 0.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 0.0 \n", + "CATDNDMRFCASSFGPDEQYF 0.0 0.0 0.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 0.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 0.0 \n", + "\n", + " TRAV_TCRAV21 TRAV_TCRAV23/DV6 TRAV_TCRAV3 \\\n", + "cdr3 \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 0.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 0.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 0.0 \n", + "... ... ... ... \n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 0.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 0.0 \n", + "CATDNDMRFCASSFGPDEQYF 0.0 0.0 0.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 0.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 0.0 \n", + "\n", + " TRAV_TCRAV38-1 TRAV_TCRAV38-2/DV8 \\\n", + "cdr3 \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 \n", + "... ... ... \n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 \n", + "CATDNDMRFCASSFGPDEQYF 0.0 0.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 \n", + "\n", + " TRAV_TCRAV41 TRAV_TRAV-2 ... TRBJ_TRBJ2-5 \\\n", + "cdr3 ... \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 ... 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 ... 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 ... 1.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 ... 1.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 ... 1.0 \n", + "... ... ... ... ... \n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 ... 0.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 ... 0.0 \n", + "CATDNDMRFCASSFGPDEQYF 0.0 0.0 ... 0.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 ... 0.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 ... 
0.0 \n", + "\n", + " TRBJ_TRBJ2-6 TRBJ_TRBJ2-7 TRBJ_TRBJ2-7 \\\n", + "cdr3 \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 0.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 0.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 0.0 \n", + "... ... ... ... \n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 0.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 0.0 \n", + "CATDNDMRFCASSFGPDEQYF 0.0 0.0 0.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 0.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 0.0 \n", + "\n", + " TRBJ_TRBJ20-1 TRBJ_TRBJ24-1 \\\n", + "cdr3 \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 \n", + "... ... ... \n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 \n", + "CATDNDMRFCASSFGPDEQYF 0.0 0.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 \n", + "\n", + " TRBJ_TRBJ38-2/DV8 TRBJ_TRBJ5-1 TRBJ_TRBJ5-6 \\\n", + "cdr3 \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 0.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 0.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 0.0 \n", + "... ... ... ... \n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 0.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 0.0 \n", + "CATDNDMRFCASSFGPDEQYF 0.0 0.0 0.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 0.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 0.0 \n", + "\n", + " TRBJ_nan \n", + "cdr3 \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 \n", + "... ... 
\n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 1.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 1.0 \n", + "CATDNDMRFCASSFGPDEQYF 1.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 1.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 1.0 \n", + "\n", + "[14106 rows x 508 columns]" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
TRAV_TCRAV12-1TRAV_TCRAV17TRAV_TCRAV19TRAV_TCRAV21TRAV_TCRAV23/DV6TRAV_TCRAV3TRAV_TCRAV38-1TRAV_TCRAV38-2/DV8TRAV_TCRAV41TRAV_TRAV-2...TRBJ_TRBJ2-5TRBJ_TRBJ2-6TRBJ_TRBJ2-7TRBJ_TRBJ2-7TRBJ_TRBJ20-1TRBJ_TRBJ24-1TRBJ_TRBJ38-2/DV8TRBJ_TRBJ5-1TRBJ_TRBJ5-6TRBJ_nan
cdr3
CAGGADRLTFCASSPAGNTLYF0.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
CAASGGSNYNVLYFCAWSLWGGPSAETLYF0.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
CAASYNYAQGLTFCASRDWGGRQDTQYF0.00.00.00.00.00.00.00.00.00.0...1.00.00.00.00.00.00.00.00.00.0
CAAQTGNYKYVFCASGDAGTGQDTQYF0.00.00.00.00.00.00.00.00.00.0...1.00.00.00.00.00.00.00.00.00.0
CAASLTGGYKVVFCAWRTDNQDTQYF0.00.00.00.00.00.00.00.00.00.0...1.00.00.00.00.00.00.00.00.00.0
..................................................................
CAYRSGEYGNKLVFCASSMAGSSYEQYF0.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.01.0
CAYRSFNNNDMRFCASRSRGGHSPLHF0.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.01.0
CATDNDMRFCASSFGPDEQYF0.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.01.0
CAVLNNARLMFCASSVDRVADTQYF0.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.01.0
CAMTSFQKLVFCASSLRGEKNNYGYTF0.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.01.0
\n", + "

14106 rows × 508 columns

\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "variable_name": "df_encoded_TCR" + } + }, + "metadata": {}, + "execution_count": 5 + } + ], + "source": [ + "##################### encode additional information for the TCRs #####################\n", + "from sklearn.preprocessing import OneHotEncoder\n", + "\n", + "# One hot encoding of categorical variables\n", + "columns_to_encode = ['TRAV','TRAJ','TRBV','TRBJ']\n", + "df = combined_df.loc[:,columns_to_encode]\n", + "\n", + "one_hot_encoder = OneHotEncoder(drop='first', sparse_output=False)\n", + "# Fit and transform the data\n", + "one_hot_encoded = one_hot_encoder.fit_transform(df)\n", + "feature_names = one_hot_encoder.get_feature_names_out(df.columns)\n", + "df_encoded = pd.DataFrame(one_hot_encoded, columns=feature_names)\n", + "\n", + "df_encoded.index = combined_df[\"cdr3\"]\n", + "df_encoded_TCR = df_encoded[~df_encoded.index.duplicated(keep=\"first\")]\n", + "df_encoded_TCR" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "vUdehGAmrGra" + }, + "outputs": [], + "source": [ + "###################################### enter parameters ######################################\n", + "# Setting_new_epitope_new_TCR_LOOCV\n", + "\n", + "setting = \"Setting_new_epitope_new_TCR_LOOCV\"\n", + "\n", + "# \"ESM3 + VJ genes\" \"all features\" \"ESMonly\" \"withoutESM\"\n", + "features_name = \"ESM3 + VJ genes\"\n", + "\n", + "# MHCI MHCII all\n", + "MHC_class = \"all\"\n", + "species = \"all\" # HomoSapiens all\n", + "tcr_features = \"cdr3\"\n", + "N_TCRs = 2 # epitopes with at least N TCRs\n", + "ratio = 5 # neg / pos ratio\n", + "repetition = 5\n", + "algorithm = \"sklearn_logit\"\n", + "result_folder = \"MixTCRpred/output/\"+setting+\"/\"\n", + "nfolds = 5 # here it is only for Gridsearch\n", + "n_jobs = -1\n", + "\n", + "os.chdir(path)\n", + "os.makedirs(result_folder,exist_ok=True)\n", + "os.chdir(result_folder)\n", + "" + ] + }, + { + 
"cell_type": "code", + "source": [ + "combined_df.MHC_class.value_counts()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 198 + }, + "id": "C9NE89S-N92J", + "outputId": "256e76eb-0fe6-40c0-ddf6-6b72876ec604" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "MHC_class\n", + "MHCI 13248\n", + "MHCII 4428\n", + "Name: count, dtype: int64" + ], + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
count
MHC_class
MHCI13248
MHCII4428
\n", + "

" + ] + }, + "metadata": {}, + "execution_count": 9 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# MHC, species\n", + "if species != \"all\":\n", + " combined_df = combined_df.loc[combined_df[\"species\"]==species,:]\n", + "if MHC_class != \"all\":\n", + " combined_df = combined_df.loc[combined_df[\"MHC_class\"]==MHC_class, : ]\n", + "\n", + "# epitopes with at least N TCRs\n", + "mat = combined_df.pivot_table(index=tcr_features, columns='epitope', values='value', aggfunc='max')\n", + "mat = mat.loc[:, (mat.notna().sum() >= N_TCRs)]\n", + "\n", + "from joblib import Parallel, delayed\n", + "n_jobs = 32 # 32 12\n", + "results = Parallel(n_jobs=n_jobs)(delayed(run_ML)(epitope,epitope_embeddings,cdr3_embeddings) for epitope in mat.columns)\n", + "\n", + "# if results are dfs\n", + "probability_all = pd.concat([result[0] for result in results])\n", + "importance_all = pd.concat([result[1] for result in results])" + ], + "metadata": { + "id": "NAl8oDMV-9MB" + }, + "execution_count": 12, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [], + "metadata": { + "id": "7EurSAVxIbjF" + } + }, + { + "cell_type": "code", + "source": [ + "probability_all" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 424 + }, + "id": "9rt3K-0aGWX5", + "outputId": "f3e0f3cb-fce6-4c9d-cb70-e17db9859ac5" + }, + "execution_count": 18, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " split epitope iteration \\\n", + "0 test ALSKGVHFV 1 \n", + "1 test ALSKGVHFV 1 \n", + "2 test ALSKGVHFV 1 \n", + "3 test ALSKGVHFV 1 \n", + "4 test ALSKGVHFV 1 \n", + "... ... ... ... 
\n", + "83335 train ALSKGVHFV 1 \n", + "83336 train ALSKGVHFV 1 \n", + "83337 train ALSKGVHFV 1 \n", + "83338 train ALSKGVHFV 1 \n", + "83339 train ALSKGVHFV 1 \n", + "\n", + " sample predicted_prob RealClass \n", + "0 ALSKGVHFV_CAVEDGQKLLFCASSPGGTATYEQYF 0.998830 1.0 \n", + "1 ALSKGVHFV_CALGSDSWGKLQFCASSLAGDSYNEQFF 0.877127 1.0 \n", + "2 ALSKGVHFV_CALSEGRDDKIIFCASSIVPWDTQYF 0.411237 1.0 \n", + "3 ALSKGVHFV_CAVAPFGNEKLTFCASSTQSTVNIQYF 0.700061 1.0 \n", + "4 ALSKGVHFV_CAMRGRTGNQFYFCASSQKLAGDNEQFF 0.884197 1.0 \n", + "... ... ... ... \n", + "83335 CTELKLSDY_CAVTTDSWGKLQFCASSRQPMNTEAFF 0.460549 0.0 \n", + "83336 CTELKLSDY_CAAKEGYSTLTFCASSEGDRVTEAFF 0.230751 0.0 \n", + "83337 CTELKLSDY_CAVVYPLTHGSSNTGKLIFCASSLEGQLNEQFF 0.478130 0.0 \n", + "83338 CTELKLSDY_CAAEAGAGNKLTFCASGDSANSDYTF 0.083402 0.0 \n", + "83339 CTELKLSDY_CAMRVSGGSNAKLTFCASRGGANTGQLYF 0.241574 0.0 \n", + "\n", + "[83340 rows x 6 columns]" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
splitepitopeiterationsamplepredicted_probRealClass
0testALSKGVHFV1ALSKGVHFV_CAVEDGQKLLFCASSPGGTATYEQYF0.9988301.0
1testALSKGVHFV1ALSKGVHFV_CALGSDSWGKLQFCASSLAGDSYNEQFF0.8771271.0
2testALSKGVHFV1ALSKGVHFV_CALSEGRDDKIIFCASSIVPWDTQYF0.4112371.0
3testALSKGVHFV1ALSKGVHFV_CAVAPFGNEKLTFCASSTQSTVNIQYF0.7000611.0
4testALSKGVHFV1ALSKGVHFV_CAMRGRTGNQFYFCASSQKLAGDNEQFF0.8841971.0
.....................
83335trainALSKGVHFV1CTELKLSDY_CAVTTDSWGKLQFCASSRQPMNTEAFF0.4605490.0
83336trainALSKGVHFV1CTELKLSDY_CAAKEGYSTLTFCASSEGDRVTEAFF0.2307510.0
83337trainALSKGVHFV1CTELKLSDY_CAVVYPLTHGSSNTGKLIFCASSLEGQLNEQFF0.4781300.0
83338trainALSKGVHFV1CTELKLSDY_CAAEAGAGNKLTFCASGDSANSDYTF0.0834020.0
83339trainALSKGVHFV1CTELKLSDY_CAMRVSGGSNAKLTFCASRGGANTGQLYF0.2415740.0
\n", + "

83340 rows × 6 columns

\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "variable_name": "probability_all", + "summary": "{\n \"name\": \"probability_all\",\n \"rows\": 83340,\n \"fields\": [\n {\n \"column\": \"split\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"train\",\n \"test\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"epitope\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 1,\n \"samples\": [\n \"ALSKGVHFV\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"iteration\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 1,\n \"max\": 1,\n \"num_unique_values\": 1,\n \"samples\": [\n 1\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"sample\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 83340,\n \"samples\": [\n \"NYNYLYRLF_CAARPEPTSTGTALIFCASSEGQGYEQYF\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"predicted_prob\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.3198026349121069,\n \"min\": 0.0001305711276764333,\n \"max\": 0.9999985774716064,\n \"num_unique_values\": 83340,\n \"samples\": [\n 0.3508570053190846\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"RealClass\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.37445226353599625,\n \"min\": 0.0,\n \"max\": 1.0,\n \"num_unique_values\": 2,\n \"samples\": [\n 0.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" + } + }, + "metadata": {}, + "execution_count": 18 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# save\n", + "probability_all.to_csv(\"probability_\"+features_name+\"_\"+algorithm+\"_N_TCRs\"+str(N_TCRs)+\"_species_\"+species+\"_MHC_\"+MHC_class+\".csv\")\n", + 
"importance_all.to_csv(\"importance_\"+features_name+\"_\"+algorithm+\"_N_TCRs\"+str(N_TCRs)+\"_species_\"+species+\"_MHC_\"+MHC_class+\".csv\")\n" + ], + "metadata": { + "id": "oJ0ysrzz_nfd" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#probability_all = pd.read_csv(\"probability_\"+features_name+\"_\"+algorithm+\"_N_TCRs\"+str(N_TCRs)+\"_species_\"+species+\"_MHC_\"+MHC_class+\".csv\", index_col=0)\n", + "#importance_all = pd.read_csv(\"importance_\"+features_name+\"_\"+algorithm+\"_N_TCRs\"+str(N_TCRs)+\"_species_\"+species+\"_MHC_\"+MHC_class+\".csv\", index_col=0)\n" + ], + "metadata": { + "id": "7u0xKg7nRwB-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Compute, save and plot aggregated feature importance across runs\n", + "importance_agg = aggregate_feature_importance(imp=importance_all, f_name=features_name)\n", + "# Plot feature importance\n", + "importance_agg['abs_importance'] = importance_agg['importance'].abs()\n", + "importance_agg_sorted = importance_agg.sort_values(by='abs_importance',ascending=False).drop(columns=['abs_importance'])\n", + "\n", + "top_30 = importance_agg_sorted.head(30)\n", + "plt.figure(figsize=(7, 7))\n", + "plt.barh(top_30.feature_name, top_30.importance, color='blue')\n", + "plt.xlabel('Feature Importance')\n", + "plt.ylabel('Feature Name')\n", + "plt.title('Top 30 features')\n", + "plt.gca().invert_yaxis()\n", + "plt.tight_layout()\n", + "os.makedirs('fig', exist_ok=True)\n", + "plt.savefig(\"fig/importance_\"+algorithm+\"_\"+features_name+\"_species_\"+species+\"_MHC_\"+MHC_class+\".pdf\")" + ], + "metadata": { + "id": "BWgIagpeKLDZ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "D6lYfPZFr4Te" + }, + "outputs": [], + "source": [ + 
"################################## plot AUC by epitope ##################################\n", + "\n", + "split = \"test\"\n", + "probability = probability_all.loc[probability_all.split==split,:]\n", + "\n", + "auc_values = plot_ROC_curve(probability)\n", + "\n", + "# Save performance\n", + "result = pd.DataFrame({'algorithm':[algorithm],'median_auc': [np.median(auc_values)] })\n", + "result.to_csv(\"auc_values_epitope_specific_mean_\"+features_name+\"_\"+algorithm+\"_N_TCRs\"+str(N_TCRs)+\"_\"+split+\"_species_\"+species+\"_MHC_\"+MHC_class+\".csv\")\n", + "auc_values.to_csv(\"auc_values_epitope_specific_\"+features_name+\"_\"+algorithm+\"_N_TCRs\"+str(N_TCRs)+\"_\"+split+\"_species_\"+species+\"_MHC_\"+MHC_class+\".csv\")" + ] + }, + { + "cell_type": "code", + "source": [ + "auc_values" + ], + "metadata": { + "id": "N_xEef-4RLt8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "################################## plot AUC by epitope ##################################\n", + "\n", + "split = \"train\"\n", + "probability = probability_all.loc[probability_all.split==split,:]\n", + "\n", + "auc_values = plot_ROC_curve(probability)\n", + "\n", + "# Save performance\n", + "result = pd.DataFrame({'algorithm':[algorithm],'median_auc': [np.median(auc_values)] })\n", + "result.to_csv(\"auc_values_epitope_specific_mean_\"+features_name+\"_\"+algorithm+\"_N_TCRs\"+str(N_TCRs)+\"_\"+split+\"_species_\"+species+\"_MHC_\"+MHC_class+\".csv\")\n", + "auc_values.to_csv(\"auc_values_epitope_specific_\"+features_name+\"_\"+algorithm+\"_N_TCRs\"+str(N_TCRs)+\"_\"+split+\"_species_\"+species+\"_MHC_\"+MHC_class+\".csv\")\n" + ], + "metadata": { + "id": "0qib6rM0q4jl" + }, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "colab": { + "provenance": [], + "machine_shape": "hm" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + 
"nbformat_minor": 0 +} \ No newline at end of file