-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
34 lines (25 loc) · 944 Bytes
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import numpy as np
import torch
import random
import get_data
import util
from hyperparams import all_hyperparams
from data_settings import all_settings
###################################################################################################
# main block
###################################################################################################


def main(dataset_name='bp_traj', approach='baseline1', seed=0):
    """Run one experiment configuration end to end.

    Seeds the ``random``, NumPy and PyTorch RNGs, loads the dataset for
    ``dataset_name`` via ``get_data.get_dataset``, looks up the hyperparameters
    for ``(dataset_name, approach)``, and obtains the model and results from
    ``util.get_model``. For the ``'bp_traj'`` dataset it additionally calls
    ``util.analyze_mod``.

    Args:
        dataset_name: Key into ``all_settings`` and ``all_hyperparams``.
            Options listed by the original script: 'no_bp', 'bp_stats', 'bp_traj'.
        approach: Key into ``all_hyperparams[dataset_name]``. The original
            script annotated 'baseline1' as a standard LSTM.
        seed: Seed applied to ``random``, ``numpy.random`` and ``torch``.

    Returns:
        The ``(mod, results)`` pair returned by ``util.get_model``.
    """
    # Seed every RNG the pipeline may touch before any data is loaded.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Shallow copy so that adding 'n_feats' below does not mutate the shared
    # module-level all_settings entry (matters when main() is called repeatedly).
    data_params = dict(all_settings[dataset_name])
    data_package = get_data.get_dataset(dataset_name, data_params)
    hyperparams = all_hyperparams[dataset_name][approach]

    # Feature count is read from the second dimension of the first element;
    # assumes data_package[0][0][0] is a 2-D array-like (e.g. time x features)
    # — TODO confirm against get_data.get_dataset.
    data_params['n_feats'] = data_package[0][0][0].shape[1]

    mod, results = util.get_model(data_package, approach, data_params, hyperparams)
    if dataset_name == 'bp_traj':
        util.analyze_mod(approach, hyperparams, data_params, data_package)
    return mod, results


if __name__ == '__main__':
    main()