From 062597ff53845b7d896f6ca43e0dfb308d815aa3 Mon Sep 17 00:00:00 2001 From: Julian Klug Date: Thu, 5 Sep 2024 15:22:06 +0200 Subject: [PATCH] xgb gridsearch eval --- .../val_evaluation_exploration.ipynb | 110 ++++++++++++---- .../xgb_gridsearch_evaluation.ipynb | 121 ++++++++++++++++++ 2 files changed, 209 insertions(+), 22 deletions(-) create mode 100644 prediction/short_term_outcome_prediction/explorations/xgb_gridsearch_evaluation.ipynb diff --git a/prediction/short_term_outcome_prediction/explorations/val_evaluation_exploration.ipynb b/prediction/short_term_outcome_prediction/explorations/val_evaluation_exploration.ipynb index 2691dd3..e9a005c 100644 --- a/prediction/short_term_outcome_prediction/explorations/val_evaluation_exploration.ipynb +++ b/prediction/short_term_outcome_prediction/explorations/val_evaluation_exploration.ipynb @@ -6,8 +6,8 @@ "id": "f7ba06e20a49bef9", "metadata": { "ExecuteTime": { - "end_time": "2024-08-21T12:03:55.467883Z", - "start_time": "2024-08-21T12:03:52.392947Z" + "end_time": "2024-08-30T20:36:27.137340Z", + "start_time": "2024-08-30T20:36:22.649204Z" } }, "outputs": [], @@ -29,15 +29,16 @@ "id": "initial_id", "metadata": { "ExecuteTime": { - "end_time": "2024-08-21T12:08:28.374201Z", - "start_time": "2024-08-21T12:08:28.356427Z" + "end_time": "2024-08-30T20:38:08.135137Z", + "start_time": "2024-08-30T20:38:08.132496Z" } }, "outputs": [], "source": [ - "# data_path = '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_08062024_083500/early_neurological_deterioration_train_data_splits/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n", - "data_path = '/Users/jk1/Downloads/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n", - "model_path = '/Users/jk1/temp/opsum_end/training/hyperopt/gridsearch/best/checkpoints_short_opsum_transformer_20240814_073845_cv_0/short_opsum_transformer_epoch=07_val_auroc=0.8399.ckpt'" + "data_path = '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_08062024_083500/early_neurological_deterioration_train_data_splits/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n", + "# data_path = '/Users/jk1/Downloads/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n", + "model_path = '/Users/jk1/temp/opsum_end/training/hyperopt/gridsearch/best/checkpoints_short_opsum_transformer_20240814_073845_cv_0/short_opsum_transformer_epoch=07_val_auroc=0.8399.ckpt'\n", + "predictions_path = '/Users/jk1/Downloads/predictions.pt'" ] }, { @@ -46,8 +47,8 @@ "id": "bf0c7d6ca65b6f0b", "metadata": { "ExecuteTime": { - "end_time": "2024-08-21T12:08:28.701093Z", - "start_time": "2024-08-21T12:08:28.697551Z" + "end_time": "2024-08-30T20:36:27.145744Z", + "start_time": "2024-08-30T20:36:27.143941Z" } }, "outputs": [], @@ -62,8 +63,8 @@ "id": "2dce32a69246812e", "metadata": { "ExecuteTime": { - "end_time": "2024-08-21T12:08:30.843398Z", - "start_time": "2024-08-21T12:08:30.837564Z" + "end_time": "2024-08-30T20:36:27.150058Z", + "start_time": "2024-08-30T20:36:27.147379Z" } }, "outputs": [], @@ -77,8 +78,8 @@ "id": "5efbdbb0d769aa26", "metadata": { "ExecuteTime": { - "end_time": "2024-08-21T12:08:32.112582Z", - "start_time": "2024-08-21T12:08:31.082577Z" + "end_time": "2024-08-30T20:41:07.955659Z", + "start_time": "2024-08-30T20:38:11.307171Z" } }, "outputs": [], @@ -86,26 +87,46 @@ "splits = ch.load(os.path.join(data_path))" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f904dcf12dc92e1", + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-30T20:47:05.545Z", + "start_time": "2024-08-30T20:47:05.542470Z" + } + }, + "outputs": [], + "source": [ + "best_cv_fold = 0" + ] + }, { "cell_type": "code", "execution_count": null, "id": "9c691b4eb5b9445a", "metadata": { "ExecuteTime": { - "end_time": "2024-08-21T12:08:33.149945Z", - "start_time": "2024-08-21T12:08:33.142490Z" + "end_time": "2024-08-30T20:47:05.724664Z", + "start_time": "2024-08-30T20:47:05.721719Z" } }, "outputs": [], "source": [ - "full_X_train, full_X_val, y_train, y_val = splits[0]" + "full_X_train, full_X_val, y_train, y_val = splits[best_cv_fold]" ] }, { "cell_type": "code", "execution_count": null, "id": "de8c6d9d776707d7", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-30T20:47:07.454197Z", + "start_time": "2024-08-30T20:47:06.103251Z" + } + }, "outputs": [], "source": [ "from sklearn.preprocessing import StandardScaler\n", @@ -123,7 +144,12 @@ "cell_type": "code", "execution_count": null, "id": "1ba7844ff708aaa3", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-30T20:47:07.459483Z", + "start_time": "2024-08-30T20:47:07.456237Z" + } + }, "outputs": [], "source": [ "X_val.shape" @@ -248,6 +274,36 @@ "pred_over_ts_np.shape" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "36cac6350dc4c980", + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-30T20:37:46.204476Z", + "start_time": "2024-08-30T20:37:46.198438Z" + } + }, + "outputs": [], + "source": [ + "predictions_data = ch.load(predictions_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae5bb5f725d8d4fa", + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-30T20:49:28.710916Z", + "start_time": "2024-08-30T20:49:28.708083Z" + } + }, + "outputs": [], + "source": [ + "pred_over_ts_np = np.squeeze(predictions_data).T" + ] + }, { "cell_type": "markdown", "id": "d139cd87e9791456", @@ -264,8 +320,8 @@ "id": "aeb58a810e199a6", "metadata": { "ExecuteTime": { - "end_time": "2024-08-21T12:10:46.404588Z", - "start_time": "2024-08-21T12:10:46.362979Z" + "end_time": "2024-08-30T20:49:30.624598Z", + "start_time": "2024-08-30T20:49:30.583680Z" } }, "outputs": [], @@ -293,7 +349,12 @@ "cell_type": "code", "execution_count": null, "id": "b9194088aac8b132", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-30T20:49:34.220412Z", + "start_time": "2024-08-30T20:49:33.705901Z" + } + }, "outputs": [], "source": [ "# compute roc scores for each time step\n", @@ -314,7 +375,12 @@ "cell_type": "code", "execution_count": null, "id": "d2fa82defc65a470", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-30T20:49:34.744525Z", + "start_time": "2024-08-30T20:49:34.739255Z" + } + }, "outputs": [], "source": [ "np.nanmedian(roc_scores)" diff --git a/prediction/short_term_outcome_prediction/explorations/xgb_gridsearch_evaluation.ipynb b/prediction/short_term_outcome_prediction/explorations/xgb_gridsearch_evaluation.ipynb new file mode 100644 index 0000000..b2116f5 --- /dev/null +++ b/prediction/short_term_outcome_prediction/explorations/xgb_gridsearch_evaluation.ipynb @@ -0,0 +1,121 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-23T10:34:57.630025Z", + "start_time": "2024-08-23T10:34:57.627164Z" + } + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import os\n", + "import seaborn as sns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7dfd825b243cf65", + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-23T10:35:30.472389Z", + "start_time": "2024-08-23T10:35:30.468200Z" + } + }, + "outputs": [], + "source": [ + "log_folder_path = '/Users/jk1/Downloads/xgb_gs'\n", + "output_dir = '/Users/jk1/Downloads'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5df24c93e6096cfd", + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-23T10:35:46.517467Z", + "start_time": "2024-08-23T10:35:46.374150Z" + } + }, + "outputs": [], + "source": [ + "gs_df = pd.DataFrame()\n", + "for root, dirs, files in os.walk(log_folder_path):\n", + " for file in files:\n", + " if file.endswith('.jsonl'):\n", + " temp_df = pd.read_json(os.path.join(root, file), \n", + " lines=True, dtype={'timestamp': 'object'}, convert_dates=False).drop(0)\n", + " # add file name as column\n", + " temp_df['file_name'] = file\n", + " gs_df = pd.concat([gs_df, temp_df], ignore_index=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce6cba7b03680cb", + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-23T10:35:49.821820Z", + "start_time": "2024-08-23T10:35:49.801767Z" + } + }, + "outputs": [], + "source": [ + "gs_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2e38efb0e502dd", + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-23T10:36:20.053438Z", + "start_time": "2024-08-23T10:36:20.040408Z" + } + }, + "outputs": [], + "source": [ + "best_df = gs_df.sort_values('median_val_scores', ascending=False).head(1)\n", + "best_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fba22c4b0b816d4c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}