# nn_train_test.py
"""This module is used for Nearest Neighbor based baselines.

Example usage:
    $ python nn_train_test.py
        --test_features ../data/forecasting_data_test.pkl
        --train_features ../data/forecasting_data_train.pkl
        --val_features ../data/forecasting_data_val.pkl
        --use_map --use_delta --n_neigh 3
        --model_path path/to/nn_model.pkl
        --traj_save_path forecasted_trajectories/nn_none.pkl

Note: --model_path and --traj_save_path are both required arguments.
"""
import argparse
import numpy as np
from typing import Any, Dict, List, Tuple, Union
import pandas as pd
import time
import ipdb
import utils.baseline_utils as baseline_utils
from utils.nn_utils import Regressor
def parse_arguments(
        argv: Union[List[str], None] = None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When None (the default) argparse reads
            ``sys.argv[1:]``, so existing zero-argument callers are
            unaffected.

    Returns:
        Namespace with the parsed options. ``--model_path`` and
        ``--traj_save_path`` are required; every other option has a
        default.
    """
    parser = argparse.ArgumentParser()

    # Paths to the pre-computed feature files of each split.
    parser.add_argument(
        "--train_features",
        default="",
        type=str,
        help="path to the file which has train features.",
    )
    parser.add_argument(
        "--val_features",
        default="",
        type=str,
        help="path to the file which has val features.",
    )
    parser.add_argument(
        "--test_features",
        default="",
        type=str,
        help="path to the file which has test features.",
    )

    # Boolean switches (all default to False).
    parser.add_argument("--test",
                        action="store_true",
                        help="Load the saved model and test")
    parser.add_argument("--use_map",
                        action="store_true",
                        help="Use the map based features")
    parser.add_argument("--use_social",
                        action="store_true",
                        help="Use social features")
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="Normalize the trajectories if non-map baseline is used.",
    )
    parser.add_argument(
        "--use_delta",
        action="store_true",
        help="Train on the change in position, instead of absolute position",
    )

    # Integer hyper-parameters.
    parser.add_argument(
        "--joblib_batch_size",
        default=100,
        type=int,
        help="Batch size for parallel computation",
    )
    parser.add_argument("--obs_len",
                        default=20,
                        type=int,
                        help="Observed length of the trajectory")
    parser.add_argument("--pred_len",
                        default=30,
                        type=int,
                        help="Prediction Horizon")
    parser.add_argument(
        "--n_neigh",
        default=1,
        type=int,
        help=
        "Number of Nearest Neighbors to take. For map-based baselines, it is number of neighbors along each centerline.",
    )

    # Required output locations.
    parser.add_argument(
        "--model_path",
        required=True,
        type=str,
        help=
        "path to the pickle file where the model will be / has been saved.",
    )
    parser.add_argument(
        "--traj_save_path",
        required=True,
        type=str,
        help=
        "path to the pickle file where forecasted trajectories will be saved.",
    )

    # parse_args(None) falls back to sys.argv[1:].
    return parser.parse_args(argv)
def perform_k_nn_experiments(
        data_dict: Dict[str, Union[np.ndarray, pd.DataFrame, None]],
        baseline_key: str) -> None:
    """Perform various experiments using K Nearest Neighbor Regressor.

    The train and val splits are concatenated into a single training set,
    which is handed together with the test split to the Regressor routine
    selected by ``--use_map``.

    Args:
        data_dict (dict): Dictionary of train/val/test data. The keys read
            here are "train_input", "train_output", "val_input",
            "val_output", "test_helpers" and (non-map branch only)
            "test_input".
        baseline_key: Key for obtaining features for the baseline

    """
    args = parse_arguments()

    # Get model object for the baseline
    model = Regressor()

    # Merge train and val splits into one training set.
    # NOTE(review): the original comment says K-fold cross validation is used
    # instead of a fixed val split; that is not visible in this function --
    # presumably it happens inside Regressor, confirm there.
    train_val_input = np.concatenate(
        (data_dict["train_input"], data_dict["val_input"]))
    train_val_output = np.concatenate(
        (data_dict["train_output"], data_dict["val_output"]))

    # Only the test helpers are forwarded to the model. The train/val helpers
    # and the test output are not consumed by either call below, so they are
    # deliberately not loaded or concatenated here.
    test_helpers = data_dict["test_helpers"]

    if args.use_map:
        print("#### Training Nearest Neighbor in NT frame ###")
        model.train_and_infer_map(
            train_val_input,
            train_val_output,
            test_helpers,
            len(baseline_utils.BASELINE_INPUT_FEATURES[baseline_key]),
            args,
        )
    else:
        print("#### Training Nearest Neighbor in absolute map frame ###")
        model.train_and_infer_absolute(
            train_val_input,
            train_val_output,
            data_dict["test_input"],
            test_helpers,
            len(baseline_utils.BASELINE_INPUT_FEATURES[baseline_key]),
            args,
        )
def main():
    """Load data and perform experiments.

    Parses the command line, stops early when
    ``baseline_utils.validate_args`` rejects the arguments, seeds NumPy's
    global RNG, derives the baseline key from ``--use_map`` /
    ``--use_social``, loads the data and runs the k-NN experiments while
    reporting the elapsed wall-clock time in minutes.
    """
    args = parse_arguments()
    if not baseline_utils.validate_args(args):
        return

    np.random.seed(100)

    # Get features: (use_map, use_social) -> baseline key
    key_by_flags = {
        (True, True): "map_social",
        (True, False): "map",
        (False, True): "social",
        (False, False): "none",
    }
    baseline_key = key_by_flags[(bool(args.use_map), bool(args.use_social))]

    # Get data
    data_dict = baseline_utils.get_data(args, baseline_key)

    # Perform experiments
    start = time.time()
    perform_k_nn_experiments(data_dict, baseline_key)
    end = time.time()
    print(f"Completed experiment in {(end-start)/60.0} mins")
if __name__ == "__main__":
    # Run the experiment only when executed as a script, not on import.
    main()