-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0c2b16b
commit 5142956
Showing
7 changed files
with
594 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
351 changes: 351 additions & 0 deletions
351
prediction/short_term_outcome_prediction/explorations/xgb_exploration.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,351 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "9f49f2fed1db301f", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T18:45:21.974599Z", | ||
"start_time": "2024-08-21T18:45:21.972027Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import pandas as pd\n", | ||
"import torch as ch\n", | ||
"import numpy as np\n", | ||
"from os import path" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "initial_id", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T18:45:23.700015Z", | ||
"start_time": "2024-08-21T18:45:23.484876Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"data_splits_path = '/Users/jk1/Downloads/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n", | ||
"splits = ch.load(path.join(data_splits_path))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "3e77fe7ddfdd50a2", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T19:00:09.498960Z", | ||
"start_time": "2024-08-21T19:00:09.496533Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"x_train = splits[0][0]\n", | ||
"y_train = splits[0][2]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "eedaf52fae95d516", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T18:52:25.957773Z", | ||
"start_time": "2024-08-21T18:52:25.951716Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"x_train.shape, y_train.shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "59da7c4747c285e9", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T18:52:31.255138Z", | ||
"start_time": "2024-08-21T18:52:31.242785Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"y_train" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d56b8768aab5d159", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T18:53:13.070358Z", | ||
"start_time": "2024-08-21T18:53:13.054709Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from prediction.utils.utils import aggregate_features_over_time\n", | ||
"\n", | ||
"x_train = x_train[:, :, :, -1].astype('float32')\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "19448753e6e00edd", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T18:54:14.756067Z", | ||
"start_time": "2024-08-21T18:54:14.750533Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# aggregate features over time so that one timepoint is one sample\n", | ||
"fold_X_train, fold_y_train = aggregate_features_over_time(x_train, np.array([None]), moving_average=False)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c955188932d756a6", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T18:54:15.756309Z", | ||
"start_time": "2024-08-21T18:54:15.753026Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"fold_X_train.shape, fold_y_train.shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a90f803d21159d77", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T19:00:13.885711Z", | ||
"start_time": "2024-08-21T19:00:13.873698Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from prediction.short_term_outcome_prediction.timeseries_decomposition import decompose_and_label_timeseries\n", | ||
"\n", | ||
"map, flat_labels = decompose_and_label_timeseries(x_train, y_train)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c3d8b7a8e6ed2ad5", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T19:00:15.836770Z", | ||
"start_time": "2024-08-21T19:00:15.821116Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"map" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "66a3fe706147638e", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T19:03:24.271876Z", | ||
"start_time": "2024-08-21T19:03:24.268266Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"timeseries = splits[0][0]\n", | ||
"y_df = splits[0][2]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b0afc8bd05bfe118", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T19:47:10.969633Z", | ||
"start_time": "2024-08-21T19:47:10.962431Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"def aggregate_and_label_timeseries(timeseries, y_df, target_time_to_outcome=6, mask_after_first_positive=True):\n", | ||
" \n", | ||
" all_subj_labels = []\n", | ||
" all_subj_data = []\n", | ||
" n_timepoints = timeseries.shape[1]\n", | ||
" for idx, cid in enumerate(timeseries[:, 0, 0, 0]):\n", | ||
" x_data = timeseries[None, idx, :, :, -1].astype('float32')\n", | ||
" if cid not in y_df.case_admission_id.values:\n", | ||
" labels = np.zeros(n_timepoints)\n", | ||
" else:\n", | ||
" event_ts = int(y_df[y_df.case_admission_id == cid].relative_sample_date_hourly_cat.values[0])\n", | ||
" # let labels be 0 until 6 ts before the event then 1 until the end then 0\n", | ||
" n_pos_start = max(0, event_ts - target_time_to_outcome)\n", | ||
" n_pos_end = event_ts\n", | ||
" labels = np.concatenate((np.zeros(n_pos_start), np.ones(n_pos_end - n_pos_start), np.zeros(n_timepoints - n_pos_end)))\n", | ||
" \n", | ||
" if mask_after_first_positive:\n", | ||
" labels = labels[:n_pos_start + 1]\n", | ||
" x_data = x_data[:, :n_pos_start + 1, :]\n", | ||
" x_data, _ = aggregate_features_over_time(x_data, np.array([None]), moving_average=False)\n", | ||
" all_subj_labels.append(labels)\n", | ||
" all_subj_data.append(x_data)\n", | ||
" \n", | ||
" return all_subj_data, all_subj_labels" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "da1a2ed34f9ed308", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T19:47:11.454876Z", | ||
"start_time": "2024-08-21T19:47:11.448501Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from sklearn.preprocessing import StandardScaler\n", | ||
"\n", | ||
"\n", | ||
"def prepare_aggregate_dataset(scenario, rescale=True, target_time_to_outcome=6, mask_after_first_positive=True):\n", | ||
" \"\"\"\n", | ||
" Prepares the dataset as an aggregate dataset (one sample per timepoint) and returns the train and validation sets.\n", | ||
"\n", | ||
" Args:\n", | ||
" scenario (tuple): tuple of (X_train, X_val, y_train, y_val)\n", | ||
" rescale (bool): whether to rescale the data or not\n", | ||
" target_time_to_outcome (int): number of timesteps to predict in the future\n", | ||
" \"\"\"\n", | ||
" X_train, X_val, y_train, y_val = scenario\n", | ||
"\n", | ||
" train_data, train_labels = aggregate_and_label_timeseries(X_train, y_train, target_time_to_outcome, mask_after_first_positive)\n", | ||
" val_data, val_labels = aggregate_and_label_timeseries(X_val, y_val, target_time_to_outcome, mask_after_first_positive)\n", | ||
" \n", | ||
" train_data = np.concatenate(train_data)\n", | ||
" train_labels = np.concatenate(train_labels)\n", | ||
" \n", | ||
" val_data = np.concatenate(val_data)\n", | ||
" val_labels = np.concatenate(val_labels)\n", | ||
" \n", | ||
" scaler = StandardScaler()\n", | ||
" if rescale:\n", | ||
" train_data = scaler.fit_transform(train_data)\n", | ||
" val_data = scaler.transform(val_data)\n", | ||
" \n", | ||
" \n", | ||
" return train_data, val_data, train_labels, val_labels\n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1a47aff06b4d69d", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T19:47:17.386503Z", | ||
"start_time": "2024-08-21T19:47:17.344368Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"all_datasets = [prepare_aggregate_dataset(x, rescale=True, target_time_to_outcome=6, mask_after_first_positive=True) for x in splits]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d48b128c0138b88e", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T19:47:17.843212Z", | ||
"start_time": "2024-08-21T19:47:17.839717Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"splits[0][0].shape, splits[0][2].shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2d6ceada21542003", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T19:47:18.707699Z", | ||
"start_time": "2024-08-21T19:47:18.704202Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"all_datasets[0][0].shape, all_datasets[0][2].shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f4bc9691092bd8bf", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-21T19:46:54.651002Z", | ||
"start_time": "2024-08-21T19:46:54.647111Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f1ad0faa36ad39c8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.