Skip to content

Commit

Permalink
Timeseries prediction using Encoder/Decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
JulianKlug committed Aug 21, 2024
1 parent 217649a commit 439094d
Show file tree
Hide file tree
Showing 15 changed files with 1,147 additions and 111 deletions.
42 changes: 42 additions & 0 deletions prediction/outcome_prediction/Transformer/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,45 @@ def forward(self, x):
x = self.encoder(x)
x = self.classifier(x)
return x


class OPSUM_encoder_decoder(nn.Module):
    """Encoder/decoder transformer mapping a source sequence to a feature sequence.

    Both the source sequence and the decoder input are embedded with the same
    linear layer and positional encoding, passed through ``torch.nn.Transformer``
    and projected back to ``input_dim`` features per decoder step.

    Args:
        input_dim: number of features per timestep (also the output width).
        num_layers: number of transformer encoder layers.
        num_decoder_layers: number of transformer decoder layers.
        model_dim: width of the linear embedding (the transformer runs at
            twice this width).
        ff_dim: feed-forward width inside the transformer layers.
        num_heads: number of attention heads.
        dropout: dropout rate given to the positional encoding and transformer.
        pos_encode_factor: forwarded to ``PositionalEncoding`` as ``factor``.
        layer_norm_eps: epsilon for the transformer layer norms.
        n_tokens: only used to size ``step_linear``.
        max_dim: forwarded to ``PositionalEncoding`` as its third argument.
    """

    def __init__(self, input_dim, num_layers, num_decoder_layers, model_dim, ff_dim,
                 num_heads, dropout, pos_encode_factor=0.1, layer_norm_eps=1e-05, n_tokens=1,
                 max_dim=5000):
        super().__init__()
        self.embedder = nn.Linear(input_dim, model_dim)
        self.pe = PositionalEncoding(model_dim, dropout, max_dim, factor=pos_encode_factor)

        # The transformer operates at twice the embedding width.
        # NOTE(review): presumably PositionalEncoding widens its input by a
        # factor of two (e.g. by concatenation) - confirm against its definition.
        transformer_dim = 2 * model_dim

        self.transformer = ch.nn.Transformer(
            d_model=transformer_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=ff_dim,
            dropout=dropout,
            layer_norm_eps=layer_norm_eps,
            batch_first=True,
        )
        self.feature_linear = nn.Linear(transformer_dim, input_dim)
        # Not used by forward(); kept so the module's parameters are unchanged.
        self.step_linear = nn.Linear(n_tokens + 1, n_tokens)

    def _embed(self, sequence):
        """Embed a (batch, time, features) tensor and apply the positional encoding."""
        batch_size, n_steps, n_features = sequence.shape
        embedded = self.embedder(sequence.reshape(-1, n_features))
        return self.pe(embedded.reshape(batch_size, n_steps, -1))

    def forward(self, x, tgt):
        """Run the transformer on ``x`` (source) and ``tgt`` (decoder input).

        Both inputs are 3-D tensors; the result has ``input_dim`` features on
        its last axis.
        """
        source = self._embed(x)
        decoder_input = self._embed(tgt)

        decoded = self.transformer(source, decoder_input)

        # Project the transformer width back down to the feature width.
        return self.feature_linear(decoded)
95 changes: 95 additions & 0 deletions prediction/outcome_prediction/Transformer/lightning_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import pytorch_lightning as pl
from torchmetrics import AUROC
from torchmetrics.classification import Accuracy
from torchmetrics.regression import CosineSimilarity

from prediction.outcome_prediction.Transformer.architecture import OPSUM_encoder_decoder


class LitModel(pl.LightningModule):
Expand Down Expand Up @@ -79,3 +82,95 @@ def warmup(current_step: int):
'frequency': 1
}
]


class LitEncoderDecoderModel(pl.LightningModule):
    """Lightning wrapper training an encoder/decoder model with an MSE loss.

    The wrapped model is called as ``model(x, y_input)`` where ``y_input`` is
    the last timestep of ``x`` (kept as a length-1 sequence). Cosine similarity
    between the flattened predictions and targets is logged per epoch.

    Args:
        model: module called as ``model(x, y_input)``.
        lr: learning rate for Adam.
        wd: weight decay for Adam.
        train_noise: scale of the Gaussian noise added to the training inputs
            (``0`` disables it).
        lr_warmup_steps: number of warmup steps before the exponential decay
            schedule takes over (``0`` disables warmup).
    """

    def __init__(self, model, lr, wd, train_noise, lr_warmup_steps=0):
        super().__init__()
        self.model = model
        self.lr = lr
        self.lr_warmup_steps = lr_warmup_steps
        self.wd = wd
        self.train_noise = train_noise

        self.criterion = ch.nn.MSELoss()

        # Kept only so existing references to the attribute still resolve; it is
        # no longer updated (see training_step).
        self.train_cos_sim = CosineSimilarity()
        self.train_cos_sim_epoch = CosineSimilarity()
        self.val_cos_sim_epoch = CosineSimilarity()

    def _predict_from_last_step(self, x):
        """Call the model with the last timestep of ``x`` as the decoder input."""
        y_input = x[:, -1:, :]  # keep the time axis: (batch, 1, features)
        return self.model(x, y_input)

    def training_step(self, batch, batch_idx, mode='train'):
        """Compute the MSE loss for one training batch and log epoch-level metrics."""
        x, y = batch
        if self.train_noise != 0:
            # One noise value per (sample, timestep), shared across all features.
            noise = ch.randn(x.shape[0], x.shape[1], device=x.device)
            x = x + noise[:, :, None] * self.train_noise

        # The decoder input is taken from the (possibly noised) inputs.
        predictions = self._predict_from_last_step(x)

        loss = self.criterion(predictions, y)

        # Only the logged metric is updated. The previous extra update of
        # `self.train_cos_sim` was never logged, computed or reset, so its
        # buffered state would presumably have kept growing for the whole run.
        self.train_cos_sim_epoch(predictions.ravel(), y.ravel())

        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train_cos_sim", self.train_cos_sim_epoch, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx, mode='train'):
        """Compute the MSE loss for one validation batch and log epoch-level metrics."""
        x, y = batch
        predictions = self._predict_from_last_step(x)

        loss = self.criterion(predictions, y)

        self.val_cos_sim_epoch(predictions.ravel(), y.ravel())

        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_cos_sim", self.val_cos_sim_epoch, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def predict_step(self, batch, batch_idx):
        """Return the model predictions for one ``(x, y)`` batch; ``y`` is ignored."""
        x, _ = batch
        return self._predict_from_last_step(x)

    def configure_optimizers(self):
        """Build the Adam optimizer and its learning-rate schedule.

        Without warmup, an ``ExponentialLR(gamma=0.99)`` scheduler is returned.
        With warmup, a ``LambdaLR`` warmup is chained in front of it through
        ``SequentialLR`` and the combined scheduler is stepped every step.

        Refs:
        - https://stackoverflow.com/questions/65343377/adam-optimizer-with-warmup-on-pytorch
        - https://github.com/Lightning-AI/lightning/issues/328#issuecomment-782845008
        """
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.wd)

        train_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.99)

        if self.lr_warmup_steps == 0:
            return [optimizer], [train_scheduler]

        # using warmup scheduler
        def warmup(current_step: int):
            # Same factor as 1 / 10 ** (warmup_steps - current_step), but written
            # with a negative exponent: a float power underflows to 0.0 instead
            # of raising OverflowError when the gap reaches ~309 steps.
            return 10.0 ** float(current_step - self.lr_warmup_steps)

        warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup)

        scheduler = optim.lr_scheduler.SequentialLR(optimizer, [warmup_scheduler, train_scheduler],
                                                    [self.lr_warmup_steps])

        return [optimizer], [
            {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        ]

13 changes: 7 additions & 6 deletions prediction/outcome_prediction/Transformer/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,22 @@ class MyEarlyStopping(Callback):
best_so_far = 0
last_improvement = 0

def __init__(self, step_limit=10, metric='val_auroc'):
    """Configure the early-stopping callback.

    Args:
        step_limit: stored as ``self.step_limit``; compared against the
            count of validation runs without improvement.
        metric: key looked up in ``trainer.callback_metrics``.
    """
    super().__init__()
    self.metric = metric
    self.step_limit = step_limit

def on_validation_end(self, trainer, pl_module):
    """Update the improvement counter and decide whether training should stop.

    Reads ``trainer.callback_metrics[self.metric]`` and sets
    ``trainer.should_stop`` when any of the following holds:
    the metric fell below 75% of the best value so far, more than
    ``self.step_limit`` validations passed without improvement, or the
    metric is still below 0.55 after epoch 10.
    """
    current = trainer.callback_metrics[self.metric].item()

    if current > self.best_so_far:
        self.last_improvement = 0
    else:
        self.last_improvement += 1

    print(self.last_improvement)

    # Each stop condition is compared against the best value *before* this run.
    collapsed = current < 0.75 * self.best_so_far
    stalled = self.last_improvement > self.step_limit
    never_took_off = trainer.current_epoch > 10 and current < 0.55
    trainer.should_stop = collapsed or stalled or never_took_off

    self.best_so_far = max(current, self.best_so_far)
1 change: 1 addition & 0 deletions prediction/outcome_prediction/Transformer/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(self, version):
self._version = version

def log_metrics(self, metrics, step=None):
    """Print a metrics record and append it to ``self.metrics``.

    ``step`` is accepted but not used.
    """
    print(metrics)
    self.metrics.append(metrics)

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import optuna
import json

from prediction.short_term_outcome_prediction.gridsearch_transformer import get_score
from prediction.short_term_outcome_prediction.gridsearch_transformer_encoder import get_score_encoder
from prediction.short_term_outcome_prediction.gridsearch_transformer_encoder_decoder import get_score_encoder_decoder
from prediction.short_term_outcome_prediction.timeseries_decomposition import prepare_subsequence_dataset


def subprocess_cluster_gridsearch(data_splits_path:str, output_folder:str, trial_name:str, gridsearch_config_path: dict,
use_gpu:bool=True,
use_gpu:bool=True, use_decoder:bool=False,
storage_pwd:str=None, storage_port:int=None, storage_host:str='localhost'):
# load config
with open(gridsearch_config_path, 'r') as f:
Expand All @@ -27,9 +28,14 @@ def subprocess_cluster_gridsearch(data_splits_path:str, output_folder:str, trial
splits = ch.load(path.join(data_splits_path))
all_datasets = [prepare_subsequence_dataset(x, use_gpu=use_gpu) for x in splits]

study.optimize(partial(get_score, ds=all_datasets, data_splits_path=data_splits_path, output_folder=output_folder,
if use_decoder:
study.optimize(partial(get_score_encoder_decoder, ds=all_datasets, data_splits_path=data_splits_path, output_folder=output_folder,
gridsearch_config=gridsearch_config,
use_gpu=use_gpu), n_trials=gridsearch_config['n_trials'])
else:
study.optimize(partial(get_score_encoder, ds=all_datasets, data_splits_path=data_splits_path, output_folder=output_folder,
gridsearch_config=gridsearch_config,
use_gpu=use_gpu), n_trials=gridsearch_config['n_trials'])


if __name__ == '__main__':
Expand All @@ -41,13 +47,16 @@ def subprocess_cluster_gridsearch(data_splits_path:str, output_folder:str, trial
parser.add_argument('-t', '--trial_name', type=str, required=True)
parser.add_argument('-c', '--gridsearch_config_path', type=str, required=True)
parser.add_argument('-g', '--use_gpu', type=str, required=False, default=1)
parser.add_argument('-dec', '--use_decoder', type=str, required=False, default=0)

parser.add_argument('-spwd', '--storage_pwd', type=str, required=False, default=None)
parser.add_argument('-sport', '--storage_port', type=int, required=False, default=None)
parser.add_argument('-shost', '--storage_host', type=str, required=False, default='localhost')

args = parser.parse_args()

# Flags arrive as strings from the command line (or as the int default), so
# they are normalised to real booleans before being passed on.
use_gpu = args.use_gpu in (1, '1', 'True')
use_decoder = args.use_decoder in (1, '1', 'True')
# Pass the parsed boolean, not the raw argument: the raw string '0' is truthy
# and would wrongly select the encoder/decoder search.
subprocess_cluster_gridsearch(args.data_splits_path, args.output_folder, args.trial_name,
                              args.gridsearch_config_path,
                              use_gpu=use_gpu, use_decoder=use_decoder,
                              storage_pwd=args.storage_pwd, storage_port=args.storage_port,
                              storage_host=args.storage_host)
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
"id": "initial_id",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-12T11:45:30.320074Z",
"start_time": "2024-08-12T11:45:30.316606Z"
"end_time": "2024-08-19T16:20:12.129902Z",
"start_time": "2024-08-19T16:20:12.127798Z"
}
},
"outputs": [],
Expand All @@ -24,8 +24,8 @@
"id": "801ca5fa6e5486ff",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-12T11:45:46.031267Z",
"start_time": "2024-08-12T11:45:46.027641Z"
"end_time": "2024-08-19T16:20:12.133910Z",
"start_time": "2024-08-19T16:20:12.131454Z"
}
},
"outputs": [],
Expand All @@ -40,8 +40,8 @@
"id": "38525202312f478",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-12T11:47:45.683875Z",
"start_time": "2024-08-12T11:47:45.628982Z"
"end_time": "2024-08-19T16:20:12.195459Z",
"start_time": "2024-08-19T16:20:12.135270Z"
}
},
"outputs": [],
Expand All @@ -53,6 +53,8 @@
" if file.endswith('.jsonl'):\n",
" temp_df = pd.read_json(os.path.join(root, file), \n",
" lines=True, dtype={'timestamp': 'object'}, convert_dates=False).drop(0)\n",
" # add file name as column\n",
" temp_df['file_name'] = file\n",
" gs_df = pd.concat([gs_df, temp_df], ignore_index=True)\n"
]
},
Expand All @@ -62,8 +64,8 @@
"id": "4168d945d95ab438",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-12T11:47:47.651657Z",
"start_time": "2024-08-12T11:47:47.626756Z"
"end_time": "2024-08-19T16:20:12.218765Z",
"start_time": "2024-08-19T16:20:12.196679Z"
}
},
"outputs": [],
Expand All @@ -77,8 +79,8 @@
"id": "76e926edc7bbaba",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-12T11:48:19.824527Z",
"start_time": "2024-08-12T11:48:19.807887Z"
"end_time": "2024-08-19T16:20:13.460060Z",
"start_time": "2024-08-19T16:20:13.444592Z"
}
},
"outputs": [],
Expand All @@ -88,14 +90,29 @@
"best_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e2ff3f473b3744",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-19T16:20:32.530714Z",
"start_time": "2024-08-19T16:20:32.525048Z"
}
},
"outputs": [],
"source": [
"# best_df.to_csv(os.path.join(output_dir, 'end_transformer_best_hyperparameters.csv'), index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5575b1ac9754ede",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-12T11:48:42.067654Z",
"start_time": "2024-08-12T11:48:41.757260Z"
"end_time": "2024-08-18T14:36:18.690573Z",
"start_time": "2024-08-18T14:36:18.166203Z"
}
},
"outputs": [],
Expand All @@ -113,8 +130,8 @@
"id": "4ad2eb1e6da667eb",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-12T11:52:12.124761Z",
"start_time": "2024-08-12T11:52:09.989603Z"
"end_time": "2024-08-18T14:36:28.002346Z",
"start_time": "2024-08-18T14:36:18.693276Z"
}
},
"outputs": [],
Expand Down Expand Up @@ -147,7 +164,12 @@
"cell_type": "code",
"execution_count": null,
"id": "a8198a950e7b6945",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-18T14:36:28.010684Z",
"start_time": "2024-08-18T14:36:28.006496Z"
}
},
"outputs": [],
"source": []
}
Expand Down
Loading

0 comments on commit 439094d

Please sign in to comment.