Skip to content

Commit

Permalink
aggregate xgb for end prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
JulianKlug committed Aug 21, 2024
1 parent 0c2b16b commit 5142956
Show file tree
Hide file tree
Showing 7 changed files with 594 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,23 +114,22 @@ def training_step(self, batch, batch_idx, mode='train'):

loss = self.criterion(predictions, y)

self.train_cos_sim(predictions, y)
self.train_cos_sim_epoch(predictions, y)
self.train_cos_sim(predictions.reshape(x.shape[0],-1), y.reshape(x.shape[0],-1))
self.train_cos_sim_epoch(predictions.reshape(x.shape[0],-1), y.reshape(x.shape[0],-1))

self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
self.log("train_cos_sim", self.train_cos_sim_epoch, on_step=False, on_epoch=True, prog_bar=True)

return loss

def validation_step(self, batch, batch_idx, mode='train'):
print('Validation step', batch_idx)
x, y = batch
y_input = x[:, -1, :][:, None, :]
predictions = self.model(x, y_input)

loss = self.criterion(predictions, y)

self.val_cos_sim_epoch(predictions, y)
self.val_cos_sim(predictions.reshape(x.shape[0],-1), y.reshape(x.shape[0],-1))

self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
self.log("val_cos_sim", self.val_cos_sim_epoch, on_step=False, on_epoch=True, prog_bar=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,11 @@
module = LitEncoderDecoderModel(model, lr, wd, train_noise, lr_warmup_steps=n_lr_warm_up_steps)
trainer = pl.Trainer(accelerator=accelerator, devices=1, max_epochs=max_epochs,
logger=[logger, pl.loggers.TensorBoardLogger(output_folder, name=f'{timestamp}_cv_{i}')],
log_every_n_steps=1, enable_checkpointing=True,
log_every_n_steps=50, enable_checkpointing=False,
callbacks=[MyEarlyStopping(step_limit=early_stopping_step_limit, metric='val_cos_sim'),
# checkpoint_callback],
],
gradient_clip_val=grad_clip,
num_sanity_val_steps=0)
gradient_clip_val=grad_clip)
trainer.fit(model=module, train_dataloaders=train_loader, val_dataloaders=val_loader)
# trainer.validate(model=module, dataloaders=val_loader)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,351 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "9f49f2fed1db301f",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T18:45:21.974599Z",
"start_time": "2024-08-21T18:45:21.972027Z"
}
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import torch as ch\n",
"import numpy as np\n",
"from os import path"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "initial_id",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T18:45:23.700015Z",
"start_time": "2024-08-21T18:45:23.484876Z"
}
},
"outputs": [],
"source": [
"data_splits_path = '/Users/jk1/Downloads/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n",
"splits = ch.load(path.join(data_splits_path))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e77fe7ddfdd50a2",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T19:00:09.498960Z",
"start_time": "2024-08-21T19:00:09.496533Z"
}
},
"outputs": [],
"source": [
"x_train = splits[0][0]\n",
"y_train = splits[0][2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eedaf52fae95d516",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T18:52:25.957773Z",
"start_time": "2024-08-21T18:52:25.951716Z"
}
},
"outputs": [],
"source": [
"x_train.shape, y_train.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59da7c4747c285e9",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T18:52:31.255138Z",
"start_time": "2024-08-21T18:52:31.242785Z"
}
},
"outputs": [],
"source": [
"y_train"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d56b8768aab5d159",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T18:53:13.070358Z",
"start_time": "2024-08-21T18:53:13.054709Z"
}
},
"outputs": [],
"source": [
"from prediction.utils.utils import aggregate_features_over_time\n",
"\n",
"x_train = x_train[:, :, :, -1].astype('float32')\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19448753e6e00edd",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T18:54:14.756067Z",
"start_time": "2024-08-21T18:54:14.750533Z"
}
},
"outputs": [],
"source": [
"# aggregate features over time so that one timepoint is one sample\n",
"fold_X_train, fold_y_train = aggregate_features_over_time(x_train, np.array([None]), moving_average=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c955188932d756a6",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T18:54:15.756309Z",
"start_time": "2024-08-21T18:54:15.753026Z"
}
},
"outputs": [],
"source": [
"fold_X_train.shape, fold_y_train.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a90f803d21159d77",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T19:00:13.885711Z",
"start_time": "2024-08-21T19:00:13.873698Z"
}
},
"outputs": [],
"source": [
"from prediction.short_term_outcome_prediction.timeseries_decomposition import decompose_and_label_timeseries\n",
"\n",
"map, flat_labels = decompose_and_label_timeseries(x_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3d8b7a8e6ed2ad5",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T19:00:15.836770Z",
"start_time": "2024-08-21T19:00:15.821116Z"
}
},
"outputs": [],
"source": [
"map"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "66a3fe706147638e",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T19:03:24.271876Z",
"start_time": "2024-08-21T19:03:24.268266Z"
}
},
"outputs": [],
"source": [
"timeseries = splits[0][0]\n",
"y_df = splits[0][2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0afc8bd05bfe118",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T19:47:10.969633Z",
"start_time": "2024-08-21T19:47:10.962431Z"
}
},
"outputs": [],
"source": [
"def aggregate_and_label_timeseries(timeseries, y_df, target_time_to_outcome=6, mask_after_first_positive=True):\n",
" \n",
" all_subj_labels = []\n",
" all_subj_data = []\n",
" n_timepoints = timeseries.shape[1]\n",
" for idx, cid in enumerate(timeseries[:, 0, 0, 0]):\n",
" x_data = timeseries[None, idx, :, :, -1].astype('float32')\n",
" if cid not in y_df.case_admission_id.values:\n",
" labels = np.zeros(n_timepoints)\n",
" else:\n",
" event_ts = int(y_df[y_df.case_admission_id == cid].relative_sample_date_hourly_cat.values[0])\n",
" # let labels be 0 until 6 ts before the event then 1 until the end then 0\n",
" n_pos_start = max(0, event_ts - target_time_to_outcome)\n",
" n_pos_end = event_ts\n",
" labels = np.concatenate((np.zeros(n_pos_start), np.ones(n_pos_end - n_pos_start), np.zeros(n_timepoints - n_pos_end)))\n",
" \n",
" if mask_after_first_positive:\n",
" labels = labels[:n_pos_start + 1]\n",
" x_data = x_data[:, :n_pos_start + 1, :]\n",
" x_data, _ = aggregate_features_over_time(x_data, np.array([None]), moving_average=False)\n",
" all_subj_labels.append(labels)\n",
" all_subj_data.append(x_data)\n",
" \n",
" return all_subj_data, all_subj_labels"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da1a2ed34f9ed308",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T19:47:11.454876Z",
"start_time": "2024-08-21T19:47:11.448501Z"
}
},
"outputs": [],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"\n",
"def prepare_aggregate_dataset(scenario, rescale=True, target_time_to_outcome=6, mask_after_first_positive=True):\n",
" \"\"\"\n",
" Prepares the dataset as an aggregate dataset (one sample per timepoint) and returns the train and validation sets.\n",
"\n",
" Args:\n",
" scenario (tuple): tuple of (X_train, X_val, y_train, y_val)\n",
" rescale (bool): whether to rescale the data or not\n",
" target_time_to_outcome (int): number of timesteps to predict in the future\n",
" \"\"\"\n",
" X_train, X_val, y_train, y_val = scenario\n",
"\n",
" train_data, train_labels = aggregate_and_label_timeseries(X_train, y_train, target_time_to_outcome, mask_after_first_positive)\n",
" val_data, val_labels = aggregate_and_label_timeseries(X_val, y_val, target_time_to_outcome, mask_after_first_positive)\n",
" \n",
" train_data = np.concatenate(train_data)\n",
" train_labels = np.concatenate(train_labels)\n",
" \n",
" val_data = np.concatenate(val_data)\n",
" val_labels = np.concatenate(val_labels)\n",
" \n",
" scaler = StandardScaler()\n",
" if rescale:\n",
" train_data = scaler.fit_transform(train_data)\n",
" val_data = scaler.transform(val_data)\n",
" \n",
" \n",
" return train_data, val_data, train_labels, val_labels\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a47aff06b4d69d",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T19:47:17.386503Z",
"start_time": "2024-08-21T19:47:17.344368Z"
}
},
"outputs": [],
"source": [
"all_datasets = [prepare_aggregate_dataset(x, rescale=True, target_time_to_outcome=6, mask_after_first_positive=True) for x in splits]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d48b128c0138b88e",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T19:47:17.843212Z",
"start_time": "2024-08-21T19:47:17.839717Z"
}
},
"outputs": [],
"source": [
"splits[0][0].shape, splits[0][2].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d6ceada21542003",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T19:47:18.707699Z",
"start_time": "2024-08-21T19:47:18.704202Z"
}
},
"outputs": [],
"source": [
"all_datasets[0][0].shape, all_datasets[0][2].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4bc9691092bd8bf",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-21T19:46:54.651002Z",
"start_time": "2024-08-21T19:46:54.647111Z"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1ad0faa36ad39c8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit 5142956

Please sign in to comment.