Skip to content

Commit

Permalink
xgb gridsearch eval
Browse files Browse the repository at this point in the history
  • Loading branch information
JulianKlug committed Sep 5, 2024
1 parent 5142956 commit 062597f
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
"id": "f7ba06e20a49bef9",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T12:03:55.467883Z",
"start_time": "2024-08-21T12:03:52.392947Z"
"end_time": "2024-08-30T20:36:27.137340Z",
"start_time": "2024-08-30T20:36:22.649204Z"
}
},
"outputs": [],
Expand All @@ -29,15 +29,16 @@
"id": "initial_id",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T12:08:28.374201Z",
"start_time": "2024-08-21T12:08:28.356427Z"
"end_time": "2024-08-30T20:38:08.135137Z",
"start_time": "2024-08-30T20:38:08.132496Z"
}
},
"outputs": [],
"source": [
"# data_path = '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_08062024_083500/early_neurological_deterioration_train_data_splits/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n",
"data_path = '/Users/jk1/Downloads/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n",
"model_path = '/Users/jk1/temp/opsum_end/training/hyperopt/gridsearch/best/checkpoints_short_opsum_transformer_20240814_073845_cv_0/short_opsum_transformer_epoch=07_val_auroc=0.8399.ckpt'"
"data_path = '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_08062024_083500/early_neurological_deterioration_train_data_splits/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n",
"# data_path = '/Users/jk1/Downloads/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n",
"model_path = '/Users/jk1/temp/opsum_end/training/hyperopt/gridsearch/best/checkpoints_short_opsum_transformer_20240814_073845_cv_0/short_opsum_transformer_epoch=07_val_auroc=0.8399.ckpt'\n",
"predictions_path = '/Users/jk1/Downloads/predictions.pt'"
]
},
{
Expand All @@ -46,8 +47,8 @@
"id": "bf0c7d6ca65b6f0b",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T12:08:28.701093Z",
"start_time": "2024-08-21T12:08:28.697551Z"
"end_time": "2024-08-30T20:36:27.145744Z",
"start_time": "2024-08-30T20:36:27.143941Z"
}
},
"outputs": [],
Expand All @@ -62,8 +63,8 @@
"id": "2dce32a69246812e",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T12:08:30.843398Z",
"start_time": "2024-08-21T12:08:30.837564Z"
"end_time": "2024-08-30T20:36:27.150058Z",
"start_time": "2024-08-30T20:36:27.147379Z"
}
},
"outputs": [],
Expand All @@ -77,35 +78,55 @@
"id": "5efbdbb0d769aa26",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T12:08:32.112582Z",
"start_time": "2024-08-21T12:08:31.082577Z"
"end_time": "2024-08-30T20:41:07.955659Z",
"start_time": "2024-08-30T20:38:11.307171Z"
}
},
"outputs": [],
"source": [
"splits = ch.load(os.path.join(data_path))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f904dcf12dc92e1",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-30T20:47:05.545Z",
"start_time": "2024-08-30T20:47:05.542470Z"
}
},
"outputs": [],
"source": [
"best_cv_fold = 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c691b4eb5b9445a",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T12:08:33.149945Z",
"start_time": "2024-08-21T12:08:33.142490Z"
"end_time": "2024-08-30T20:47:05.724664Z",
"start_time": "2024-08-30T20:47:05.721719Z"
}
},
"outputs": [],
"source": [
"full_X_train, full_X_val, y_train, y_val = splits[0]"
"full_X_train, full_X_val, y_train, y_val = splits[best_cv_fold]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de8c6d9d776707d7",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-30T20:47:07.454197Z",
"start_time": "2024-08-30T20:47:06.103251Z"
}
},
"outputs": [],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
Expand All @@ -123,7 +144,12 @@
"cell_type": "code",
"execution_count": null,
"id": "1ba7844ff708aaa3",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-30T20:47:07.459483Z",
"start_time": "2024-08-30T20:47:07.456237Z"
}
},
"outputs": [],
"source": [
"X_val.shape"
Expand Down Expand Up @@ -248,6 +274,36 @@
"pred_over_ts_np.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36cac6350dc4c980",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-30T20:37:46.204476Z",
"start_time": "2024-08-30T20:37:46.198438Z"
}
},
"outputs": [],
"source": [
"predictions_data = ch.load(predictions_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ae5bb5f725d8d4fa",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-30T20:49:28.710916Z",
"start_time": "2024-08-30T20:49:28.708083Z"
}
},
"outputs": [],
"source": [
"pred_over_ts_np = np.squeeze(predictions_data).T"
]
},
{
"cell_type": "markdown",
"id": "d139cd87e9791456",
Expand All @@ -264,8 +320,8 @@
"id": "aeb58a810e199a6",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T12:10:46.404588Z",
"start_time": "2024-08-21T12:10:46.362979Z"
"end_time": "2024-08-30T20:49:30.624598Z",
"start_time": "2024-08-30T20:49:30.583680Z"
}
},
"outputs": [],
Expand Down Expand Up @@ -293,7 +349,12 @@
"cell_type": "code",
"execution_count": null,
"id": "b9194088aac8b132",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-30T20:49:34.220412Z",
"start_time": "2024-08-30T20:49:33.705901Z"
}
},
"outputs": [],
"source": [
"# compute roc scores for each time step\n",
Expand All @@ -314,7 +375,12 @@
"cell_type": "code",
"execution_count": null,
"id": "d2fa82defc65a470",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-30T20:49:34.744525Z",
"start_time": "2024-08-30T20:49:34.739255Z"
}
},
"outputs": [],
"source": [
"np.nanmedian(roc_scores)"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "initial_id",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-23T10:34:57.630025Z",
"start_time": "2024-08-23T10:34:57.627164Z"
}
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import os\n",
"import seaborn as sns"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7dfd825b243cf65",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-23T10:35:30.472389Z",
"start_time": "2024-08-23T10:35:30.468200Z"
}
},
"outputs": [],
"source": [
"# Configuration: resolve locations relative to the current user's home directory\n",
"# instead of hard-coding one machine's absolute path.\n",
"downloads_dir = os.path.join(os.path.expanduser('~'), 'Downloads')\n",
"# Folder that is walked recursively for .jsonl log files.\n",
"log_folder_path = os.path.join(downloads_dir, 'xgb_gs')\n",
"output_dir = downloads_dir"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5df24c93e6096cfd",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-23T10:35:46.517467Z",
"start_time": "2024-08-23T10:35:46.374150Z"
}
},
"outputs": [],
"source": [
"# Load every .jsonl log found under log_folder_path into one DataFrame.\n",
"# Frames are collected in a list and concatenated once: growing gs_df with\n",
"# pd.concat inside the loop re-copies all accumulated rows on every iteration\n",
"# and concatenates with an empty frame on the first pass.\n",
"log_frames = []\n",
"for root, dirs, files in os.walk(log_folder_path):\n",
"    for file in files:\n",
"        if not file.endswith('.jsonl'):\n",
"            continue\n",
"        # The row labelled 0 of each log is dropped, as before\n",
"        # (presumably a non-result header record - TODO confirm).\n",
"        temp_df = pd.read_json(os.path.join(root, file),\n",
"                               lines=True, dtype={'timestamp': 'object'}, convert_dates=False).drop(0)\n",
"        # add file name as column\n",
"        temp_df['file_name'] = file\n",
"        log_frames.append(temp_df)\n",
"\n",
"# pd.concat raises ValueError on an empty list, so fall back to an empty frame\n",
"# when no log file was found (same result as the original cell).\n",
"gs_df = pd.concat(log_frames, ignore_index=True) if log_frames else pd.DataFrame()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce6cba7b03680cb",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-23T10:35:49.821820Z",
"start_time": "2024-08-23T10:35:49.801767Z"
}
},
"outputs": [],
"source": [
"gs_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2e38efb0e502dd",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-23T10:36:20.053438Z",
"start_time": "2024-08-23T10:36:20.040408Z"
}
},
"outputs": [],
"source": [
"# Order the runs from highest to lowest 'median_val_scores' and keep the top row.\n",
"best_df = (\n",
"    gs_df\n",
"    .sort_values(by='median_val_scores', ascending=False)\n",
"    .iloc[:1]\n",
")\n",
"best_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fba22c4b0b816d4c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 062597f

Please sign in to comment.