-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclean.py
111 lines (97 loc) · 4.14 KB
/
clean.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import logging
from pathlib import Path
from collections import defaultdict
import yaml
import hydra
import omegaconf
import numpy as np
import tsbench
from tsbench.utils import evaluate_trajectories, interpolate_trajectories
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
def clean_simplified_trajectory_with_gt(simplified_trajectories, gt_tids, gt_trajectories):
    """Rebuild simplified trajectories from ground-truth rows.

    Every timestamp kept in a simplified trajectory is looked up in
    ``gt_tids`` to find the ground-truth track id recorded for it, and the
    kept timestamps are grouped by that id.  Each ground-truth trajectory is
    then reduced to the rows whose timestamp was kept, always retaining its
    first and last row.

    Args:
        simplified_trajectories: Mapping of track id to a 2-D array whose
            column 0 holds the timestamps kept by the simplification.
        gt_tids: Mapping of the same track ids to a 2-D array where column 0
            is a timestamp and column 1 is the ground-truth track id at that
            timestamp.
        gt_trajectories: Mapping of ground-truth track id to a 2-D array
            whose column 0 holds timestamps.

    Returns:
        dict: One entry per key of ``gt_trajectories``, holding the selected
        rows of that ground-truth trajectory.  An empty ground-truth
        trajectory yields an empty selection.

    Raises:
        AssertionError: If a kept timestamp does not match exactly one row of
            the corresponding ``gt_tids`` array.  The exception type is kept
            for backward compatibility, but it is raised explicitly so the
            check still runs under ``python -O``.
        KeyError: If a simplified track id is missing from ``gt_tids``.
    """
    selected_timestamps = defaultdict(list)
    for tid, traj in simplified_trajectories.items():
        # Timestamp -> ground-truth id table for this track.
        track_gt_tids = gt_tids[tid]
        # Column 0 of the simplified trajectory holds the kept timestamps.
        for timestamp in traj[:, 0]:
            matches = track_gt_tids[track_gt_tids[:, 0] == timestamp]
            if len(matches) != 1:
                raise AssertionError(
                    f"expected exactly one ground-truth id for track {tid!r} "
                    f"at timestamp {timestamp!r}, found {len(matches)}"
                )
            selected_timestamps[matches[0, 1]].append(timestamp)

    corrected_simplified_trajectory = {}
    for tid, gt_traj in gt_trajectories.items():
        # .get() avoids inserting empty entries into the defaultdict for
        # ground-truth tracks that received no kept timestamps.
        kept_timestamps = selected_timestamps.get(tid, [])
        mask = np.isin(gt_traj[:, 0], kept_timestamps)
        # Make sure the first and last rows are covered; an empty trajectory
        # has no endpoints to force.
        if mask.size:
            mask[0] = True
            mask[-1] = True
        corrected_simplified_trajectory[tid] = gt_traj[mask]
    return corrected_simplified_trajectory
@hydra.main(config_path="configs", config_name="clean")
def main(cfg: omegaconf.dictconfig.DictConfig) -> None:
    """Simplify, clean against ground truth, evaluate, and save the results.

    For every dataset key, each parameter set in ``cfg.algo.params`` is run
    through the configured algorithm.  The simplified output is rebuilt from
    ground-truth rows, scored with ``evaluate_trajectories``, and the
    interpolated trajectories are dumped to a per-parameter directory.  The
    collected scores for a key are written as YAML both to a relative path
    and under the configured output directory.

    Args:
        cfg: Hydra configuration; this function reads ``algo``, ``dataset``,
            ``gt_dataset`` and ``output`` from it.
    """
    logger.info(f"Configuration Parameters:\n {omegaconf.OmegaConf.to_yaml(cfg)}")

    def algo_output_dir() -> Path:
        # <output>/<dataset name>/<basename of dataset path>/<algorithm name>
        return (
            Path(hydra.utils.to_absolute_path(cfg.output))
            / cfg.dataset.name
            / Path(cfg.dataset.path).name
            / cfg.algo.name
        )

    # Look up the algorithm class in the registry and instantiate it.
    simplifier = tsbench.ALGO_REGISTRY.get(cfg.algo.name)()

    # Instantiate the dataset to simplify; "params" is optional in the config.
    if "params" in cfg.dataset:
        dataset_kwargs = cfg.dataset.params
    else:
        dataset_kwargs = {}
    dataset_cls = tsbench.DATASET_REGISTRY.get(name=cfg.dataset.name)
    dataset = dataset_cls(
        path=hydra.utils.to_absolute_path(cfg.dataset.path), **dataset_kwargs
    )

    # Instantiate the ground-truth dataset the same way.
    if "params" in cfg.gt_dataset:
        gt_dataset_kwargs = cfg.gt_dataset.params
    else:
        gt_dataset_kwargs = {}
    gt_dataset_cls = tsbench.DATASET_REGISTRY.get(name=cfg.gt_dataset.name)
    gt_dataset = gt_dataset_cls(
        path=hydra.utils.to_absolute_path(cfg.gt_dataset.path), **gt_dataset_kwargs
    )

    for key_index, key in enumerate(cfg.dataset.key):
        trajectories = dataset.get_trajectories(key)
        gt_tids = dataset.get_trajectories("gt_tids")
        # The ground-truth key is paired with the dataset key by position.
        gt_trajectories = gt_dataset.get_trajectories(cfg.gt_dataset.key[key_index])

        results = []
        for param in cfg.algo.params:
            simplified, runtime_per_query = simplifier.simplify(trajectories, **param)
            simplified = clean_simplified_trajectory_with_gt(
                simplified, gt_tids, gt_trajectories
            )

            scores = evaluate_trajectories(gt_trajectories, simplified)
            scores["runtime"] = runtime_per_query

            # Collapse every metric (runtime included) to the mean of its values.
            param_values = dict(param)
            averaged = {}
            for metric_name, per_item in scores.items():
                averaged[metric_name] = float(np.mean(list(per_item.values())))
            record = {"param": param_values, "metrics": averaged}
            results.append(record)
            logger.info(record)

            # (0) Dump the interpolated trajectories into a directory named
            #     after this parameter set's values.
            interpolated = interpolate_trajectories(gt_trajectories, simplified)
            dump_dir = algo_output_dir() / "_".join(map(str, param.values()))
            dump_dir.mkdir(exist_ok=True, parents=True)
            gt_dataset.dump_data(dump_dir, interpolated)

        # (1) Save the results to a relative path (the local log directory).
        with open(f"result_{key}.yaml", "wt") as local_file:
            yaml.dump(results, local_file)

        # (2) And save them under the output directory as well.
        result_path = algo_output_dir() / f"result_{key}.yaml"
        # Make sure the parent directory exists before writing.
        result_path.parent.mkdir(exist_ok=True, parents=True)
        with result_path.open("wt") as shared_file:
            yaml.dump(results, shared_file)
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()