From d1c8b92bf5989af3eeca942a3eaa2331128fc558 Mon Sep 17 00:00:00 2001 From: Julian Klug Date: Tue, 28 Mar 2023 22:34:56 +0200 Subject: [PATCH] Generating plots for XGB model --- .../roc_and_pr_curve.ipynb | 54 +- .../top_shap_features_figure.ipynb | 674 ++++++++++++++++++ .../hyperoptimisation_analysis.ipynb | 9 +- .../outcome_prediction/treeModel/test_xgb.py | 16 +- 4 files changed, 739 insertions(+), 14 deletions(-) create mode 100644 prediction/outcome_prediction/treeModel/figures/top_features_shap/top_shap_features_figure.ipynb diff --git a/prediction/figures/roc_and_pr_curve_figure/roc_and_pr_curve.ipynb b/prediction/figures/roc_and_pr_curve_figure/roc_and_pr_curve.ipynb index d85fb23..ab7cc84 100644 --- a/prediction/figures/roc_and_pr_curve_figure/roc_and_pr_curve.ipynb +++ b/prediction/figures/roc_and_pr_curve_figure/roc_and_pr_curve.ipynb @@ -15,6 +15,7 @@ "metadata": {}, "outputs": [], "source": [ + "import os\n", "import pickle\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", @@ -32,10 +33,19 @@ "metadata": {}, "outputs": [], "source": [ - "lstm_bs_predictions_path = '/Users/jk1/temp/opsum_prediction_output/LSTM_72h_testing/3M_mRS01/2023_01_06_1847/test_LSTM_sigmoid_all_unchanged_0.0_2_True_RMSprop_3M mRS 0-1_128_3/bootstrapped_gt_and_pred.pkl'\n", + "lstm_bs_predictions_path = '/Users/jk1/temp/opsum_prediction_output/LSTM_72h_testing/3M_mRS02/2023_01_02_1057/test_LSTM_sigmoid_all_balanced_0.2_2_True_RMSprop_3M mRS 0-2_16_3/bootstrapped_gt_and_pred.pkl'\n", "thrive_c_bs_predictions_path = '/Users/jk1/temp/opsum_prediction_output/THRIVE_C/THRIVE_C_predictions/bootstrapped_gt_and_pred.pkl'\n", "xgb_bs_predictions_path = '/Users/jk1/temp/opsum_prediction_output/linear_72h_xgb/with_feature_aggregration/testing/bootstrapped_gt_and_pred.pkl'\n", - "# outcome = '3M mRS 0-1'" + "outcome = '3M mRS 0-2'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_dir = '/Users/jk1/Downloads'" ] }, { @@ -148,8 +158,8 @@ "ax = sns.lineplot(data=thrivec_resampled_roc_df, x='fpr', y='tpr', color=all_colors_palette[1], label='THRIVE-C (area = %0.2f)' % np.median(thrivec_bs_roc_aucs),\n", " ax=ax, errorbar='sd')\n", "\n", - "ax = sns.lineplot(data=xgb_resampled_roc_df, x='fpr', y='tpr', color=all_colors_palette[2], label='XGBoost (area = %0.2f)' % np.median(xgb_bs_roc_aucs),\n", - " ax=ax, errorbar='sd')\n", + "# ax = sns.lineplot(data=xgb_resampled_roc_df, x='fpr', y='tpr', color=all_colors_palette[2], label='XGBoost (area = %0.2f)' % np.median(xgb_bs_roc_aucs),\n", + "# ax=ax, errorbar='sd')\n", "\n", "ax.plot([0, 1], [0, 1], color='grey', lw=1, linestyle='--', alpha=0.5)\n", "\n", @@ -162,8 +172,9 @@ " legend_markers, legend_labels = ax.get_legend_handles_labels()\n", " sd1_patch = mpatches.Patch(color=all_colors_palette[0], alpha=0.3)\n", " sd2_patch = mpatches.Patch(color=all_colors_palette[1], alpha=0.3)\n", - " sd3_patch = mpatches.Patch(color=all_colors_palette[2], alpha=0.3)\n", - " sd_marker = (sd1_patch, sd2_patch, sd3_patch)\n", + " # sd3_patch = mpatches.Patch(color=all_colors_palette[2], alpha=0.3)\n", + " sd_marker = (sd1_patch, sd2_patch)\n", + " # sd_marker = (sd1_patch, sd2_patch, sd3_patch)\n", " sd_labels = '± s.d.'\n", " legend_markers.append(sd_marker)\n", " legend_labels.append(sd_labels)\n", @@ -177,6 +188,17 @@ "plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# # save fig\n", + "# fig = ax.get_figure()\n", + "# fig.savefig(os.path.join(output_dir, 'roc_curve.svg'), bbox_inches='tight')" + ] + }, { "cell_type": "markdown", "metadata": { @@ -208,8 +230,8 @@ "ax1 = sns.lineplot(data=thrivec_resampled_pr_df, x='recall', y='precision', color=all_colors_palette[1], label='THRIVE-C (area = %0.2f)' % np.median(thrivec_bs_pr_aucs),\n", " ax=ax1, errorbar='sd')\n", "\n", - "ax1 = sns.lineplot(data=xgb_resampled_pr_df, x='recall', y='precision', color=all_colors_palette[2], label='XGBoost (area = %0.2f)' % np.median(xgb_bs_pr_aucs),\n", - " ax=ax1, errorbar='sd')\n", + "# ax1 = sns.lineplot(data=xgb_resampled_pr_df, x='recall', y='precision', color=all_colors_palette[2], label='XGBoost (area = %0.2f)' % np.median(xgb_bs_pr_aucs),\n", + "# ax=ax1, errorbar='sd')\n", "\n", "\n", "ax1.set_xlabel('Recall', fontsize=label_font_size)\n", @@ -221,8 +243,9 @@ " legend_markers, legend_labels = ax1.get_legend_handles_labels()\n", " sd1_patch = mpatches.Patch(color=all_colors_palette[0], alpha=0.3)\n", " sd2_patch = mpatches.Patch(color=all_colors_palette[1], alpha=0.3)\n", - " sd3_patch = mpatches.Patch(color=all_colors_palette[2], alpha=0.3)\n", - " sd_marker = (sd1_patch, sd2_patch, sd3_patch)\n", + " # sd3_patch = mpatches.Patch(color=all_colors_palette[2], alpha=0.3)\n", + " sd_marker = (sd1_patch, sd2_patch)\n", + " # sd_marker = (sd1_patch, sd2_patch, sd3_patch)\n", " sd_labels = '± s.d.'\n", " legend_markers.append(sd_marker)\n", " legend_labels.append(sd_labels)\n", @@ -236,6 +259,17 @@ "plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# # save fig\n", + "# fig = ax1.get_figure()\n", + "# fig.savefig(os.path.join(output_dir, 'pr_curve.svg'), bbox_inches='tight')" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/prediction/outcome_prediction/treeModel/figures/top_features_shap/top_shap_features_figure.ipynb b/prediction/outcome_prediction/treeModel/figures/top_features_shap/top_shap_features_figure.ipynb new file mode 100644 index 0000000..cc9c118 --- /dev/null +++ b/prediction/outcome_prediction/treeModel/figures/top_features_shap/top_shap_features_figure.ipynb @@ -0,0 +1,674 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import shap\n", + "import numpy as np\n", + "import pandas as pd\n", + "import os\n", + "import pickle\n", + "from prediction.outcome_prediction.LSTM.testing.shap_helper_functions import check_shap_version_compatibility\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "Requirements:\n", + "- TensorFlow 1.14\n", + "- Python 3.7\n", + "- Protobuf downgrade to 3.20: `pip install protobuf==3.20`\n", + "- downgrade h5py to 2.10: `pip install h5py==2.10`\n", + "- turn off masking in LSTM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Shap values require very specific versions\n", + "check_shap_version_compatibility()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# print the JS visualization code to the notebook\n", + "shap.initjs()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "saved_model_path = '/Users/jk1/temp/opsum_prediction_output/linear_72h_xgb/with_feature_aggregration/testing/selected_xgb_model_cv3.pkl'\n", + "features_path = '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_features_01012023_233050.csv'\n", + "labels_path = '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_outcomes_01012023_233050.csv'\n", + "cat_encoding_path = os.path.join(os.path.dirname(features_path), f'logs_{os.path.basename(features_path).split(\".\")[0].split(\"_\")[-2]}_{os.path.basename(features_path).split(\".\")[0].split(\"_\")[-1]}/categorical_variable_encoding.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_dir = '/Users/jk1/Downloads'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "outcome = '3M mRS 0-2'\n", + "moving_time_average = False\n", + "test_size = 0.2\n", + "seed = 42" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "override_masking_value = False" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Load the data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from prediction.utils.utils import aggregate_features_over_time\n", + "from prediction.outcome_prediction.data_loading.data_formatting import format_to_2d_table_with_time\n", + "from sklearn.model_selection import train_test_split\n", + "from prediction.outcome_prediction.data_loading.data_formatting import features_to_numpy, \\\n", + " link_patient_id_to_outcome, numpy_to_lookup_table\n", + "\n", + "X, y = format_to_2d_table_with_time(feature_df_path=features_path, outcome_df_path=labels_path,\n", + " outcome=outcome)\n", + "\n", + "\"\"\"\n", + "SPLITTING DATA\n", + "Splitting is done by patient id (and not admission id) as in case of the rare multiple admissions per patient there\n", + "would be a risk of data leakage otherwise split 'pid' in TRAIN and TEST pid = unique patient_id\n", + "\"\"\"\n", + "# Reduce every patient to a single outcome (to avoid duplicates)\n", + "all_pids_with_outcome = link_patient_id_to_outcome(y, outcome)\n", + "pid_train, pid_test, y_pid_train, y_pid_test = train_test_split(all_pids_with_outcome.patient_id.tolist(),\n", + " all_pids_with_outcome.outcome.tolist(),\n", + " stratify=all_pids_with_outcome.outcome.tolist(),\n", + " test_size=test_size,\n", + " random_state=seed)\n", + "\n", + "# Extracting TEST data\n", + "test_X_df = X[X.patient_id.isin(pid_test)]\n", + "test_y_df = y[y.patient_id.isin(pid_test)]\n", + "\n", + "test_X_np = features_to_numpy(test_X_df,\n", + " ['case_admission_id', 'relative_sample_date_hourly_cat', 'sample_label', 'value'])\n", + "test_y_np = np.array([test_y_df[test_y_df.case_admission_id == cid].outcome.values[0] for cid in\n", + " test_X_np[:, 0, 0, 0]]).astype('float32')\n", + "test_features_lookup_table = numpy_to_lookup_table(test_X_np)\n", + "\n", + "# Remove the case_admission_id, sample_label, and time_step_label columns from the data\n", + "test_X_np = test_X_np[:, :, :, -1].astype('float32')\n", + "X_test, y_test = aggregate_features_over_time(test_X_np, test_y_np, moving_average=moving_time_average)\n", + "# only keep prediction at last timepoint\n", + "X_test = X_test.reshape(-1, 72, X_test.shape[-1])[:, -1, :].astype('float32')\n", + "y_test = y_test.reshape(-1, 72)[:, -1].astype('float32')\n", + "\n", + "\n", + "# Extracting TRAIN data\n", + "# find indexes for train admissions\n", + "X_train_df = X.loc[X.patient_id.isin(pid_train)]\n", + "y_train_df = y.loc[y.patient_id.isin(pid_train)]\n", + "\n", + "# Transform dataframes to numpy arrays\n", + "X_train = features_to_numpy(X_train_df,\n", + " ['case_admission_id', 'relative_sample_date_hourly_cat', 'sample_label',\n", + " 'value'])\n", + "y_train = np.array([y_train_df[y_train_df.case_admission_id == cid].outcome.values[0] for cid in\n", + " X_train[:, 0, 0, 0]]).astype('float32')\n", + "\n", + "# Remove the case_admission_id, sample_label, and time_step_label columns from the data\n", + "X_train = X_train[:, :, :, -1].astype('float32')\n", + "X_train, y_train = aggregate_features_over_time(X_train, y_train, moving_average=moving_time_average)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Load the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = pickle.load(open(saved_model_path, 'rb'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "explainer = shap.TreeExplainer(model)\n", + "shap_values = explainer.shap_values(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "features = list(test_features_lookup_table['sample_label'].keys())\n", + "features = np.concatenate([['last_tp_' + f for f in features], ['avg_' + f for f in features],\n", + " ['min' + f for f in features], ['max_' + f for f in features]])\n", + "features.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "shap_values.shape, X_test.shape, features.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "### Create working data frame\n", + "\n", + "Join data in a common dataframe with shap values and feature values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pd.DataFrame(data=shap_values, columns = features)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "selected_shap_values_df = pd.DataFrame(data=shap_values, columns = features)\n", + "selected_shap_values_df = selected_shap_values_df.reset_index()\n", + "selected_shap_values_df.rename(columns={'index': 'case_admission_id_idx'}, inplace=True)\n", + "selected_shap_values_df = selected_shap_values_df.melt(id_vars='case_admission_id_idx', var_name='feature', value_name='shap_value')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "selected_feature_values_df = pd.DataFrame(data=X_test, columns = features)\n", + "selected_feature_values_df = selected_feature_values_df.reset_index()\n", + "selected_feature_values_df.rename(columns={'index': 'case_admission_id_idx'}, inplace=True)\n", + "selected_feature_values_df = selected_feature_values_df.melt(id_vars='case_admission_id_idx', var_name='feature', value_name='feature_value')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "features_with_shap_values_df = pd.merge(selected_shap_values_df, selected_feature_values_df, on=['case_admission_id_idx', 'feature'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CURRENT POSITION IN CODE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reverse_categorical_encoding = True\n", + "\n", + "if reverse_categorical_encoding:\n", + " cat_encoding_df = pd.read_csv(cat_encoding_path)\n", + " for i in range(len(cat_encoding_df)):\n", + " cat_basename = cat_encoding_df.sample_label[i].lower().replace(' ', '_')\n", + " cat_item_list = cat_encoding_df.other_categories[i].replace('[', '').replace(']', '').replace('\\'', '').split(', ')\n", + " cat_item_list = [cat_basename + '_' + item.replace(' ', '_').lower() for item in cat_item_list]\n", + " for cat_item_idx, cat_item in enumerate(cat_item_list):\n", + " # retrieve the dominant category for this subject (0 being default category)\n", + " features_with_shap_values_df.loc[features_with_shap_values_df.feature == cat_item, 'feature_value'] *= cat_item_idx + 1\n", + " features_with_shap_values_df.loc[features_with_shap_values_df.feature == cat_item, 'feature'] = cat_encoding_df.sample_label[i]\n", + " # sum the shap and feature values for each subject\n", + " features_with_shap_values_df = features_with_shap_values_df.groupby(['case_admission_id_idx', 'feature']).sum().reset_index()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# give a numerical encoding to the categorical features\n", + "cat_to_numerical_encoding = {\n", + " 'Prestroke disability (Rankin)': {0:0, 1:5, 2:4, 3:2, 4:1, 5:3},\n", + " 'categorical_onset_to_admission_time': {0:1, 1:2, 2:3, 3:4, 4:0},\n", + " 'categorical_IVT': {0:2, 1:3, 2:4, 3:1, 4:0},\n", + " 'categorical_IAT': {0:1, 1:0, 2:3, 3:2}\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for cat_feature, cat_encoding in cat_to_numerical_encoding.items():\n", + " features_with_shap_values_df.loc[features_with_shap_values_df.feature == cat_feature, 'feature_value'] = features_with_shap_values_df.loc[features_with_shap_values_df.feature == cat_feature, 'feature_value'].map(cat_encoding)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pool_hourly_split_values = True\n", + "\n", + "# For features that are downsampled to hourly values, pool the values (median, min, max)\n", + "\n", + "if pool_hourly_split_values:\n", + " hourly_split_features = ['NIHSS', 'systolic_blood_pressure', 'diastolic_blood_pressure', 'heart_rate', 'respiratory_rate', 'temperature', 'oxygen_saturation']\n", + " for feature in hourly_split_features:\n", + " features_with_shap_values_df.loc[features_with_shap_values_df.feature.str.contains(feature), 'feature'] = (feature[0].upper() + feature[1:]\n", + ").replace('_', ' ')\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "Replace feature names with their english names" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "feature_to_english_name_correspondence_path = os.path.join(os.path.dirname(\n", + " os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath('__file__'))))))),\n", + " 'preprocessing/preprocessing_tools/feature_name_to_english_name_correspondence.xlsx')\n", + "feature_to_english_name_correspondence = pd.read_excel(feature_to_english_name_correspondence_path)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for feature in features_with_shap_values_df.feature.unique():\n", + " if feature in feature_to_english_name_correspondence.feature_name.values:\n", + " features_with_shap_values_df.loc[features_with_shap_values_df.feature == feature, 'feature'] = feature_to_english_name_correspondence[feature_to_english_name_correspondence.feature_name == feature].english_name.values[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Feature selection\n", + "\n", + "Select only the features that are in the top 10 most important features by mean absolute shap value" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# identify the top 10 most important features by mean absolute shap value\n", + "features_with_shap_values_df['absolute_shap_value'] = np.abs(features_with_shap_values_df['shap_value'])\n", + "top_10_features_by_mean_abs_summed_shap = features_with_shap_values_df.groupby('feature').mean().sort_values(by='absolute_shap_value', ascending=False).head(10).index.values\n", + "top_10_features_by_mean_abs_summed_shap" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "features_with_shap_values_df = features_with_shap_values_df[features_with_shap_values_df.feature.isin(top_10_features_by_mean_abs_summed_shap)]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "Alternatively, features could also be selected before joining categories and pooling hourly values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ten_most_important_features_by_mean_abs_shap = np.abs(shap_values[0]).mean(axis=(0, 1)).argsort()[::-1][0:13]\n", + "np.array(features)[ten_most_important_features_by_mean_abs_shap]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Create color palette for feature values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "all_colors_palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)\n", + "all_colors_palette" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "base_colors = sns.color_palette(['#f61067', '#012D98'], n_colors=2)\n", + "base_colors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from prediction.utils.visualisation_helper_functions import hex_to_rgb_color, create_palette\n", + "from colormath.color_objects import sRGBColor, HSVColor, LabColor, LCHuvColor, XYZColor, LCHabColor, LuvColor\n", + "\n", + "start_color = '#012D98'\n", + "end_color = '#f61067'\n", + "\n", + "# start_color= '#049b9a'\n", + "# end_color= '#012D98'\n", + "\n", + "number_of_colors = 50\n", + "\n", + "start_rgb = hex_to_rgb_color(start_color)\n", + "end_rgb = hex_to_rgb_color(end_color)\n", + "\n", + "palette = create_palette(start_rgb, end_rgb, number_of_colors, LabColor, extrapolation_length=1)\n", + "custom_cmap = sns.color_palette(palette, n_colors=number_of_colors, as_cmap=True)\n", + "sns.color_palette(palette, n_colors=number_of_colors)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Plot most important features with SHAP values\n", + "\n", + "Preqrequisites: pd.Dataframe with shap values and feature values for each feature, along with indexes for each case" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from matplotlib.colors import ListedColormap\n", + "import matplotlib.lines as mlines\n", + "from matplotlib.legend_handler import HandlerTuple\n", + "\n", + "plot_shap_direction_label = True\n", + "plot_legend = True\n", + "plot_colorbar = True\n", + "plot_feature_value_along_y = False\n", + "\n", + "tick_label_size = 11\n", + "label_font_size = 13\n", + "\n", + "row_height = 0.4\n", + "alpha = 0.8\n", + "\n", + "plt.gcf().set_size_inches(10, 10)\n", + "\n", + "\n", + "for pos, feature in enumerate(features_with_shap_values_df.feature.unique()):\n", + " shaps = features_with_shap_values_df[features_with_shap_values_df.feature.isin([feature])].shap_value.values\n", + " values = features_with_shap_values_df[features_with_shap_values_df.feature.isin([feature])].feature_value\n", + " plt.axhline(y=pos, color=\"#cccccc\", lw=0.5, dashes=(1, 5), zorder=-1)\n", + "\n", + " values = np.array(values, dtype=np.float64) # make sure this can be numeric\n", + "\n", + " N = len(shaps)\n", + " nbins = 100\n", + " quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))\n", + " inds = np.argsort(quant + np.random.randn(N) * 1e-6)\n", + " layer = 0\n", + " last_bin = -1\n", + "\n", + " if plot_feature_value_along_y:\n", + " ys = values.copy()\n", + " cluster_factor = 0.1\n", + " for ind in inds:\n", + " if quant[ind] != last_bin:\n", + " layer = 0\n", + " ys[ind] += cluster_factor * (np.ceil(layer / 2) * ((layer % 2) * 2 - 1))\n", + " layer += 1\n", + " last_bin = quant[ind]\n", + "\n", + " else:\n", + " ys = np.zeros(N)\n", + " cluster_factor = 1\n", + " for ind in inds:\n", + " if quant[ind] != last_bin:\n", + " layer = 0\n", + " ys[ind] = cluster_factor * (np.ceil(layer / 2) * ((layer % 2) * 2 - 1))\n", + " layer += 1\n", + " last_bin = quant[ind]\n", + "\n", + " ys *= 0.9 * (row_height / np.max(ys + 1))\n", + "\n", + " # trim the color range, but prevent the color range from collapsing\n", + " vmin = np.nanpercentile(values, 5)\n", + " vmax = np.nanpercentile(values, 95)\n", + " if vmin == vmax:\n", + " vmin = np.nanpercentile(values, 1)\n", + " vmax = np.nanpercentile(values, 99)\n", + " if vmin == vmax:\n", + " vmin = np.min(values)\n", + " vmax = np.max(values)\n", + " if vmin > vmax: # fixes rare numerical precision issues\n", + " vmin = vmax\n", + "\n", + " # plot the non-nan values colored by the trimmed feature value\n", + " cvals = values.astype(np.float64)\n", + " cvals_imp = cvals.copy()\n", + " cvals_imp[np.isnan(cvals)] = (vmin + vmax) / 2.0\n", + " cvals[cvals_imp > vmax] = vmax\n", + " cvals[cvals_imp < vmin] = vmin\n", + " plt.scatter(shaps, pos + ys,\n", + " cmap=ListedColormap(palette), vmin=vmin, vmax=vmax, s=16,\n", + " c=cvals, alpha=alpha, linewidth=0,\n", + " zorder=3, rasterized=len(shaps) > 500)\n", + "\n", + "\n", + "import matplotlib.cm as cm\n", + "\n", + "axis_color=\"#333333\"\n", + "if plot_colorbar:\n", + " m = cm.ScalarMappable(cmap=ListedColormap(palette))\n", + " m.set_array([0, 1])\n", + " cb = plt.colorbar(m, ticks=[0, 1], aspect=10, shrink=0.2)\n", + " cb.set_ticklabels(['Low', 'High'])\n", + " cb.ax.tick_params(labelsize=tick_label_size, length=0)\n", + " cb.set_label('Feature value', size=label_font_size)\n", + " cb.ax.yaxis.set_label_position('left')\n", + " cb.set_alpha(1)\n", + " cb.outline.set_visible(False)\n", + "\n", + "if plot_legend:\n", + " legend_markers = []\n", + " legend_labels = []\n", + " single_dot = mlines.Line2D([], [], color=palette[len(palette)//2], marker='.', linestyle='None',\n", + " markersize=10)\n", + " single_dot_label = 'Single Patient\\n(summed over time)'\n", + " legend_markers.append(single_dot)\n", + " legend_labels.append(single_dot_label)\n", + "\n", + " plt.gca().legend(legend_markers, legend_labels, title='SHAP/Feature values', fontsize=tick_label_size, title_fontsize=label_font_size,\n", + " handler_map={tuple: HandlerTuple(ndivide=None)},\n", + " loc='upper left', frameon=True)\n", + "\n", + "\n", + "plt.gca().xaxis.set_ticks_position('bottom')\n", + "plt.gca().yaxis.set_ticks_position('none')\n", + "plt.gca().spines['right'].set_visible(False)\n", + "plt.gca().spines['top'].set_visible(False)\n", + "plt.gca().spines['left'].set_visible(False)\n", + "plt.gca().tick_params(color=axis_color, labelcolor=axis_color)\n", + "\n", + "yticklabels = features_with_shap_values_df.feature.unique()\n", + "plt.yticks(range(len(features_with_shap_values_df.feature.unique())), yticklabels, fontsize=label_font_size)\n", + "plt.gca().tick_params('y', length=20, width=0.5, which='major')\n", + "plt.gca().tick_params('x', labelsize=tick_label_size)\n", + "plt.ylim(-1, len(features_with_shap_values_df.feature.unique()))\n", + "plt.xlabel('SHAP Value \\n(impact on model output)', fontsize=label_font_size)\n", + "plt.grid(color='white', axis='y')\n", + "\n", + "plt.xlim(-0.25, 0.15)\n", + "\n", + "# Plot additional explanation with the shap value X axis\n", + "if plot_shap_direction_label:\n", + " x_ticks_coordinates = plt.xticks()[0]\n", + " x_ticks_labels = [item.get_text() for item in plt.xticks()[1]]\n", + " # let x tick label be the coordinate with 2 decimals\n", + " x_ticks_labels = [f'{x_ticks_coordinate:.2f}' for x_ticks_coordinate in x_ticks_coordinates]\n", + " x_ticks_labels[0] = f'Toward worse \\noutcome'\n", + " x_ticks_labels[-1] = f'Toward better \\noutcome'\n", + " plt.xticks(x_ticks_coordinates, x_ticks_labels)\n", + "\n", + "fig = plt.gcf()\n", + "\n", + "plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig.savefig(os.path.join(output_dir, f'top_features_shap_{outcome}.svg'), bbox_inches=\"tight\", format='svg', dpi=1200)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/prediction/outcome_prediction/treeModel/hyperoptimisation_analysis.ipynb b/prediction/outcome_prediction/treeModel/hyperoptimisation_analysis.ipynb index 0edba75..ebd0c60 100644 --- a/prediction/outcome_prediction/treeModel/hyperoptimisation_analysis.ipynb +++ b/prediction/outcome_prediction/treeModel/hyperoptimisation_analysis.ipynb @@ -18,7 +18,7 @@ "metadata": {}, "outputs": [], "source": [ - "output_dir = '/Users/jk1/temp/opsum_prediction_output/linear_72h_xgb/with_feature_aggregration'" + "output_dir = '/Users/jk1/temp/opsum_prediction_output/linear_72h_xgb/moving_avg_feature_agg'" ] }, { @@ -188,6 +188,13 @@ "sns.violinplot(x='reg_lambda', y='auc_val', data=df)\n", "plt.show()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/prediction/outcome_prediction/treeModel/test_xgb.py b/prediction/outcome_prediction/treeModel/test_xgb.py index 7ecf62f..aa10178 100644 --- a/prediction/outcome_prediction/treeModel/test_xgb.py +++ b/prediction/outcome_prediction/treeModel/test_xgb.py @@ -10,9 +10,9 @@ from prediction.outcome_prediction.treeModel.feature_aggregration_xgboost import evaluate_model -def test_model(max_depth:int, learning_rate:float, n_estimators:int, reg_lambda:int, alpha:int, +def test_model(max_depth:int, learning_rate:float, n_estimators:int, reg_lambda:int, alpha:int, moving_average:bool, outcome:str, features_df_path:str, outcomes_df_path:str, output_dir:str): - optimal_model_df, trained_models, test_dataset = evaluate_model(max_depth, learning_rate, n_estimators, reg_lambda, alpha, + optimal_model_df, trained_models, test_dataset = evaluate_model(max_depth, learning_rate, n_estimators, reg_lambda, alpha, moving_average, outcome, features_df_path, outcomes_df_path, output_dir, save_models=True) X_test, y_test = test_dataset @@ -22,6 +22,11 @@ def test_model(max_depth:int, learning_rate:float, n_estimators:int, reg_lambda: best_cv_fold_idx = best_cv_fold - 1 selected_model = trained_models[best_cv_fold_idx] + # save selected model + model_path = os.path.join(output_dir, f'selected_xgb_model_cv{best_cv_fold}.pkl') + with open(model_path, 'wb') as f: + pickle.dump(selected_model, f) + # calculate overall model prediction y_pred_test = selected_model.predict_proba(X_test)[:, 1].astype('float32') @@ -138,8 +143,13 @@ def test_model(max_depth:int, learning_rate:float, n_estimators:int, reg_lambda: best_parameters_df = pd.read_csv(cli_args.best_parameters_path) + # check if moving_average exists in best_parameters_df + if 'moving_average' not in best_parameters_df.columns: + best_parameters_df['moving_average'] = False + result_df, bootstrapping_data, testing_data = test_model(int(best_parameters_df['max_depth'][0]), best_parameters_df['learning_rate'][0], int(best_parameters_df['n_estimators'][0]), best_parameters_df['reg_lambda'][0], best_parameters_df['alpha'][0], - cli_args.outcome, cli_args.feature_df_path, cli_args.outcome_df_path, cli_args.output_dir) + bool(best_parameters_df['moving_average'][0]), + cli_args.outcome, cli_args.feature_df_path, cli_args.outcome_df_path, cli_args.output_dir) result_df.to_csv(os.path.join(cli_args.output_dir, 'test_XGB_results.csv'), sep=',', index=False)