Skip to content

Commit

Permalink
Generating plots for XGB model
Browse files Browse the repository at this point in the history
  • Loading branch information
JulianKlug committed Mar 28, 2023
1 parent 2ad82f9 commit d1c8b92
Show file tree
Hide file tree
Showing 4 changed files with 739 additions and 14 deletions.
54 changes: 44 additions & 10 deletions prediction/figures/roc_and_pr_curve_figure/roc_and_pr_curve.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pickle\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
Expand All @@ -32,10 +33,19 @@
"metadata": {},
"outputs": [],
"source": [
"lstm_bs_predictions_path = '/Users/jk1/temp/opsum_prediction_output/LSTM_72h_testing/3M_mRS01/2023_01_06_1847/test_LSTM_sigmoid_all_unchanged_0.0_2_True_RMSprop_3M mRS 0-1_128_3/bootstrapped_gt_and_pred.pkl'\n",
"lstm_bs_predictions_path = '/Users/jk1/temp/opsum_prediction_output/LSTM_72h_testing/3M_mRS02/2023_01_02_1057/test_LSTM_sigmoid_all_balanced_0.2_2_True_RMSprop_3M mRS 0-2_16_3/bootstrapped_gt_and_pred.pkl'\n",
"thrive_c_bs_predictions_path = '/Users/jk1/temp/opsum_prediction_output/THRIVE_C/THRIVE_C_predictions/bootstrapped_gt_and_pred.pkl'\n",
"xgb_bs_predictions_path = '/Users/jk1/temp/opsum_prediction_output/linear_72h_xgb/with_feature_aggregration/testing/bootstrapped_gt_and_pred.pkl'\n",
"# outcome = '3M mRS 0-1'"
"outcome = '3M mRS 0-2'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output_dir = '/Users/jk1/Downloads'"
]
},
{
Expand Down Expand Up @@ -148,8 +158,8 @@
"ax = sns.lineplot(data=thrivec_resampled_roc_df, x='fpr', y='tpr', color=all_colors_palette[1], label='THRIVE-C (area = %0.2f)' % np.median(thrivec_bs_roc_aucs),\n",
" ax=ax, errorbar='sd')\n",
"\n",
"ax = sns.lineplot(data=xgb_resampled_roc_df, x='fpr', y='tpr', color=all_colors_palette[2], label='XGBoost (area = %0.2f)' % np.median(xgb_bs_roc_aucs),\n",
" ax=ax, errorbar='sd')\n",
"# ax = sns.lineplot(data=xgb_resampled_roc_df, x='fpr', y='tpr', color=all_colors_palette[2], label='XGBoost (area = %0.2f)' % np.median(xgb_bs_roc_aucs),\n",
"# ax=ax, errorbar='sd')\n",
"\n",
"ax.plot([0, 1], [0, 1], color='grey', lw=1, linestyle='--', alpha=0.5)\n",
"\n",
Expand All @@ -162,8 +172,9 @@
" legend_markers, legend_labels = ax.get_legend_handles_labels()\n",
" sd1_patch = mpatches.Patch(color=all_colors_palette[0], alpha=0.3)\n",
" sd2_patch = mpatches.Patch(color=all_colors_palette[1], alpha=0.3)\n",
" sd3_patch = mpatches.Patch(color=all_colors_palette[2], alpha=0.3)\n",
" sd_marker = (sd1_patch, sd2_patch, sd3_patch)\n",
" # sd3_patch = mpatches.Patch(color=all_colors_palette[2], alpha=0.3)\n",
" sd_marker = (sd1_patch, sd2_patch)\n",
" # sd_marker = (sd1_patch, sd2_patch, sd3_patch)\n",
" sd_labels = '± s.d.'\n",
" legend_markers.append(sd_marker)\n",
" legend_labels.append(sd_labels)\n",
Expand All @@ -177,6 +188,17 @@
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# # save fig\n",
"# fig = ax.get_figure()\n",
"# fig.savefig(os.path.join(output_dir, 'roc_curve.svg'), bbox_inches='tight')"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand Down Expand Up @@ -208,8 +230,8 @@
"ax1 = sns.lineplot(data=thrivec_resampled_pr_df, x='recall', y='precision', color=all_colors_palette[1], label='THRIVE-C (area = %0.2f)' % np.median(thrivec_bs_pr_aucs),\n",
" ax=ax1, errorbar='sd')\n",
"\n",
"ax1 = sns.lineplot(data=xgb_resampled_pr_df, x='recall', y='precision', color=all_colors_palette[2], label='XGBoost (area = %0.2f)' % np.median(xgb_bs_pr_aucs),\n",
" ax=ax1, errorbar='sd')\n",
"# ax1 = sns.lineplot(data=xgb_resampled_pr_df, x='recall', y='precision', color=all_colors_palette[2], label='XGBoost (area = %0.2f)' % np.median(xgb_bs_pr_aucs),\n",
"# ax=ax1, errorbar='sd')\n",
"\n",
"\n",
"ax1.set_xlabel('Recall', fontsize=label_font_size)\n",
Expand All @@ -221,8 +243,9 @@
" legend_markers, legend_labels = ax1.get_legend_handles_labels()\n",
" sd1_patch = mpatches.Patch(color=all_colors_palette[0], alpha=0.3)\n",
" sd2_patch = mpatches.Patch(color=all_colors_palette[1], alpha=0.3)\n",
" sd3_patch = mpatches.Patch(color=all_colors_palette[2], alpha=0.3)\n",
" sd_marker = (sd1_patch, sd2_patch, sd3_patch)\n",
" # sd3_patch = mpatches.Patch(color=all_colors_palette[2], alpha=0.3)\n",
" sd_marker = (sd1_patch, sd2_patch)\n",
" # sd_marker = (sd1_patch, sd2_patch, sd3_patch)\n",
" sd_labels = '± s.d.'\n",
" legend_markers.append(sd_marker)\n",
" legend_labels.append(sd_labels)\n",
Expand All @@ -236,6 +259,17 @@
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# # save fig\n",
"# fig = ax1.get_figure()\n",
"# fig.savefig(os.path.join(output_dir, 'pr_curve.svg'), bbox_inches='tight')"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
Loading

0 comments on commit d1c8b92

Please sign in to comment.