-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathpaper_sequences.py
107 lines (98 loc) · 4.16 KB
/
paper_sequences.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import numpy as np
import pickle
from matplotlib import pyplot as plt
from navrep.tools.wdataset import WorldModelDataset
from strictfire import StrictFire
from tqdm import tqdm
from pyniel.python_tools.path_tools import make_dir_if_not_exists
from navdreams.auto_debug import enable_auto_debug
def find_files(dir_):
    """Return the paths of pickled sequence files found directly in ``dir_``.

    A file qualifies when its name contains ``"_sequence_"`` and ends with
    ``".pkl"``. A leading ``~`` in ``dir_`` is expanded. Paths are returned
    in ``os.listdir`` order, which is arbitrary (not sorted).
    """
    root = os.path.expanduser(dir_)
    return [
        os.path.join(root, name)
        for name in os.listdir(root)
        if "_sequence_" in name and name.endswith(".pkl")
    ]
def generate_paper_sequences(dir_, dataset_dir, sequence_length):
    """Export every done-free sequence of a dataset as its own pickle file.

    Sequences are visited in a random order (``np.random.permutation``).
    Any sequence whose ``dones`` flags contain a truthy value is skipped.
    Each kept sequence is written to ``<dir_>/<sequence_length>_sequence_<k>.pkl``,
    where ``k`` counts saved files from 0. The file holds a list of
    ``sequence_length`` dicts with keys ``obs``, ``state``, ``action`` and
    ``done``.

    Args:
        dir_: Output directory, used as given (``~`` is not expanded here).
        dataset_dir: Dataset location(s) forwarded to ``WorldModelDataset``.
        sequence_length: Steps per sequence; also used as the filename prefix.

    Returns:
        The number of sequence files written.
    """
    seq_loader = WorldModelDataset(dataset_dir, sequence_length, lidar_mode="images",
                                   channel_first=False, as_torch_tensors=False,
                                   file_limit=None)
    print("{} sequences available".format(len(seq_loader)))
    print("Saving sequences")
    n_saved = 0  # distinct name: the per-step index below used to reuse `i`
    for idx in tqdm(np.random.permutation(range(len(seq_loader)))):
        # The 3rd and 5th items (y, y_rs) are not exported.
        x, a, _, x_rs, _, dones = seq_loader[idx]
        if np.any(dones):
            # NOTE(review): presumably skipped so every exported clip lies within a
            # single episode - confirm.
            continue
        real_sequence = [dict(obs=x[t], state=x_rs[t], action=a[t], done=dones[t])
                         for t in range(sequence_length)]
        path = os.path.join(dir_, "{}_sequence_{}.pkl".format(sequence_length, n_saved))
        # Close (and thereby flush) the file before reading it back.
        with open(path, "wb") as f:
            pickle.dump(real_sequence, f)
        # Immediate read-back: fail fast if the written pickle is unreadable.
        with open(path, "rb") as f:
            pickle.load(f)
        n_saved += 1
    return n_saved
def load_paper_sequences(examples, n_examples, dataset_dir, sequence_length,
                         dir_="~/navdreams_data/wm_test_data/sequences/"):
    """Load the pickled sequences for the first ``n_examples`` entries of ``examples``.

    Files named ``<sequence_length>_sequence_<idx>.pkl`` are read from ``dir_``.
    ``~`` is expanded and the directory is created if missing. If fewer than
    ``n_examples`` such files exist, the user is asked on stdin whether to
    generate them from ``dataset_dir`` with ``generate_paper_sequences``.

    Args:
        examples: Indexable of integer example indices; only the first
            ``n_examples`` entries are used.
        n_examples: How many entries of ``examples`` to load.
        dataset_dir: Dataset location(s) used if sequences must be generated.
        sequence_length: Sequence length; selects which files are considered.
        dir_: Directory holding the sequence pickles.

    Returns:
        dict mapping each requested example index to its loaded sequence,
        in the order the indices appear in ``examples``.

    Raises:
        ValueError: If the user declines generation, or if a requested index
            has no matching file.
    """
    dir_ = os.path.expanduser(dir_)
    make_dir_if_not_exists(dir_)
    # Only consider files written for this sequence_length. Files for other
    # lengths reuse the same indices and would otherwise silently overwrite
    # each other in the index -> path mapping below.
    prefix = "{}_sequence_".format(sequence_length)

    def _matching_files():
        return [f for f in find_files(dir_) if os.path.basename(f).startswith(prefix)]

    files = _matching_files()
    if len(files) < n_examples:
        print("No sequences found. Generate?")
        if not input("y/n: ").lower().startswith("y"):
            raise ValueError("Debug: Program end")
        generate_paper_sequences(dir_, dataset_dir, sequence_length)
        files = _matching_files()
    # Map the trailing integer of "<len>_sequence_<idx>.pkl" to its path.
    file_dict = {int(f.split("_")[-1].split(".")[0]): f for f in files}
    example_sequences = {}
    for i in range(n_examples):
        idx = examples[i]
        if idx not in file_dict:
            raise ValueError("No sequence found for example {}".format(idx))
        # NOTE: pickle.load is only safe on trusted files; these are expected to
        # come from generate_paper_sequences.
        with open(file_dict[idx], "rb") as fh:
            example_sequences[idx] = pickle.load(fh)
    return example_sequences
def hide_axes_but_keep_ylabel(ax):
    """Remove ticks and the frame from ``ax`` but leave its y-axis label visible.

    ``ax.set_axis_off()`` is deliberately not used: it turns off the whole
    axis, including the axis labels, which would hide the label this helper
    is meant to keep.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)
def plot_sequence(sequence, skip=1, title=""):
    """Show a subsampled sequence of observations as a single row of images.

    The frames ``sequence[0], sequence[1 + skip], sequence[2 * (1 + skip)], ...``
    are each drawn in their own column from their ``'obs'`` entry. The last
    column gets a "GT" label on its right-hand side, and ``plt.show()`` is
    called at the end.

    Args:
        sequence: Indexable of dicts that each hold an image under ``'obs'``.
        skip: Number of frames to leave out between two displayed frames
            (0 displays every frame).
        title: Title for the whole figure.

    Raises:
        ValueError: If ``skip`` is negative or ``sequence`` is empty.
    """
    stride = 1 + skip
    if stride < 1:
        raise ValueError("skip must be >= 0, got {}".format(skip))
    # One column per displayed frame. Including the trailing partial stride avoids
    # the out-of-range column an integer-division column count used to produce.
    frame_indices = list(range(0, len(sequence), stride))
    if not frame_indices:
        raise ValueError("Cannot plot an empty sequence")
    n_cols = len(frame_indices)
    # clear=True: num="dream" reuses an existing figure of that name, which would
    # otherwise keep its old axes underneath the new ones.
    fig, axes = plt.subplots(1, n_cols, num="dream", figsize=(22, 14), dpi=100,
                             clear=True)
    # plt.subplots returns a bare Axes when n_cols == 1; normalize to a flat array.
    axes = np.array(axes).reshape(-1)
    for ax, frame_idx in zip(axes, frame_indices):
        ax.imshow(sequence[frame_idx]['obs'])
    axes[-1].set_ylabel("GT", rotation=0, labelpad=50)
    axes[-1].yaxis.set_label_position("right")
    for ax in axes:
        hide_axes_but_keep_ylabel(ax)
    plt.subplots_adjust(wspace=0.1)
    # Title the figure; plt.title would only title the current (last) subplot.
    fig.suptitle(title)
    plt.show()
def main(n_examples=100, sequence_length=64):
    """Load the paper's example sequences (generating them on request) and plot each one.

    Args:
        n_examples: Number of sequences to show; example indices
            ``0 .. n_examples - 1`` are used.
        sequence_length: Steps per sequence. Selects which
            ``<sequence_length>_sequence_<idx>.pkl`` files are loaded.
    """
    # NOTE(review): the rosbag set lives under wm_experiments while the others are
    # under wm_test_data - kept as in the original, confirm it is intended.
    dataset_dir = [os.path.expanduser(path) for path in (
        "~/navdreams_data/wm_test_data/datasets/V/navrep3dalt",
        "~/navdreams_data/wm_test_data/datasets/V/navrep3dcity",
        "~/navdreams_data/wm_test_data/datasets/V/navrep3doffice",
        "~/navdreams_data/wm_test_data/datasets/V/navrep3dasl",
        "~/navdreams_data/wm_experiments/datasets/V/rosbag",
    )]
    example_sequences = load_paper_sequences(range(n_examples), n_examples,
                                             dataset_dir, sequence_length)
    for idx, sequence in example_sequences.items():
        # Title with the name of the file the sequence was actually loaded from.
        plot_sequence(sequence, title="{}_sequence_{}.pkl".format(sequence_length, idx))


if __name__ == "__main__":
    # Install navdreams' auto-debug hook (behavior defined in navdreams.auto_debug),
    # then let StrictFire map command-line arguments onto main()'s parameters.
    enable_auto_debug()
    StrictFire(main)