-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathsequence.py
292 lines (261 loc) · 11 KB
/
sequence.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
"""
Pre-training data loader. Modified from https://github.com/jannerm/diffuser/blob/main/diffuser/datasets/sequence.py
No normalization is applied here --- we always normalize the data when pre-processing it with a different script, and the normalization info is also used in RL fine-tuning.
"""
from collections import namedtuple
import numpy as np
import torch
import logging
import pickle
import random
from tqdm import tqdm
log = logging.getLogger(__name__)

# Pre-training sample: chunked actions plus a dict of observation conditions.
Batch = namedtuple("Batch", "actions conditions")
# Q-learning sample: chunked actions, conditions, and the reward/done at the
# sample's start step.
Transition = namedtuple("Transition", "actions conditions rewards dones")
# Same as Transition plus the discounted reward-to-go at the sample's start.
# The typename must equal the module attribute name: pickle looks classes up
# by their __qualname__, so a mismatching typename makes instances resolve to
# `Transition` and fail to pickle (e.g. when sent between processes).
TransitionWithReturn = namedtuple(
    "TransitionWithReturn", "actions conditions rewards dones reward_to_gos"
)
class StitchedSequenceDataset(torch.utils.data.Dataset):
    """
    Load stitched trajectories of states/actions/images, and 1-D array of traj_lengths, from npz or pkl file.

    Use the first max_n_episodes episodes (instead of random sampling).

    Example:
        states: [----------traj 1----------][---------traj 2----------] ... [---------traj N----------]
        Episode IDs (determined based on traj_lengths): [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------]

    Each sample is a ``Batch`` namedtuple of (1) chunked actions (horizon_steps
    consecutive rows of ``actions``) and (2) a dict of conditions: "state"
    holds the last cond_steps states (oldest first) and, when use_img is set,
    "rgb" holds the last img_cond_steps images. History steps that fall before
    the episode start are filled by repeating the episode's first frame.
    """

    def __init__(
        self,
        dataset_path,
        horizon_steps=64,
        cond_steps=1,
        img_cond_steps=1,
        max_n_episodes=10000,
        use_img=False,
        device="cuda:0",
    ):
        """
        Args:
            dataset_path: path to a ``.npz`` or ``.pkl`` file providing
                "traj_lengths", "states", "actions" (and "images" if use_img).
            horizon_steps: number of consecutive actions in each sample.
            cond_steps: number of state history steps in each sample.
            img_cond_steps: number of image history steps (<= cond_steps).
            max_n_episodes: only the first this-many episodes are used.
            use_img: whether to load images and return them under "rgb".
            device: torch device the loaded tensors are moved to.

        Raises:
            ValueError: if the file extension is neither .npz nor .pkl.
        """
        assert (
            img_cond_steps <= cond_steps
        ), "consider using more cond_steps than img_cond_steps"
        self.horizon_steps = horizon_steps
        self.cond_steps = cond_steps  # states (proprio, etc.)
        self.img_cond_steps = img_cond_steps
        self.device = device
        self.use_img = use_img
        self.max_n_episodes = max_n_episodes
        self.dataset_path = dataset_path

        # Load dataset to device specified
        if dataset_path.endswith(".npz"):
            dataset = np.load(dataset_path, allow_pickle=False)  # only np arrays
        elif dataset_path.endswith(".pkl"):
            # NOTE(review): pickle can execute arbitrary code; only load
            # dataset files from trusted sources.
            with open(dataset_path, "rb") as f:
                dataset = pickle.load(f)
        else:
            raise ValueError(f"Unsupported file format: {dataset_path}")

        try:
            traj_lengths = dataset["traj_lengths"][:max_n_episodes]  # 1-D array
            total_num_steps = np.sum(traj_lengths)

            # Set up indices for sampling (subclasses may override make_indices)
            self.indices = self.make_indices(traj_lengths, horizon_steps)

            # Extract states and actions up to max_n_episodes
            self.states = (
                torch.from_numpy(dataset["states"][:total_num_steps]).float().to(device)
            )  # (total_num_steps, obs_dim)
            self.actions = (
                torch.from_numpy(dataset["actions"][:total_num_steps]).float().to(device)
            )  # (total_num_steps, action_dim)
            log.info(f"Loaded dataset from {dataset_path}")
            log.info(f"Number of episodes: {min(max_n_episodes, len(traj_lengths))}")
            log.info(f"States shape/type: {self.states.shape, self.states.dtype}")
            log.info(f"Actions shape/type: {self.actions.shape, self.actions.dtype}")
            if self.use_img:
                self.images = torch.from_numpy(dataset["images"][:total_num_steps]).to(
                    device
                )  # (total_num_steps, C, H, W)
                log.info(f"Images shape/type: {self.images.shape, self.images.dtype}")
        finally:
            # np.load on an .npz returns an archive that keeps its file open
            # until closed; a pickled dict has no close() and is skipped.
            close = getattr(dataset, "close", None)
            if close is not None:
                close()

    def __getitem__(self, idx):
        """
        Return the ``Batch`` for sample ``idx``.

        Repeat states/images if using history observation at the beginning of
        the episode.
        """
        start, num_before_start = self.indices[idx]
        end = start + self.horizon_steps
        # window from the episode start up to and including `start`
        states = self.states[(start - num_before_start) : (start + 1)]
        actions = self.actions[start:end]
        # lags reaching before the episode start are clamped to its first state
        states = torch.stack(
            [
                states[max(num_before_start - t, 0)]
                for t in reversed(range(self.cond_steps))
            ]
        )  # more recent is at the end
        conditions = {"state": states}
        if self.use_img:
            images = self.images[(start - num_before_start) : end]
            images = torch.stack(
                [
                    images[max(num_before_start - t, 0)]
                    for t in reversed(range(self.img_cond_steps))
                ]
            )
            conditions["rgb"] = images
        batch = Batch(actions, conditions)
        return batch

    def make_indices(self, traj_lengths, horizon_steps):
        """
        Make indices for sampling from the dataset.

        Each index maps to a datapoint and also stores the number of steps
        before it within the same trajectory.

        Returns:
            List of ``(start, num_before_start)`` tuples, one per start step
            whose full ``horizon_steps`` action chunk lies inside its
            trajectory. Trajectories shorter than ``horizon_steps`` contribute
            no indices.
        """
        indices = []
        cur_traj_index = 0
        for traj_length in traj_lengths:
            max_start = cur_traj_index + traj_length - horizon_steps
            indices += [
                (i, i - cur_traj_index) for i in range(cur_traj_index, max_start + 1)
            ]
            cur_traj_index += traj_length
        return indices

    def set_train_val_split(self, train_split):
        """
        Randomly keep a ``train_split`` fraction of the samples for training.

        ``self.indices`` is replaced by the randomly sampled training entries.
        Not doing validation right now.

        Args:
            train_split: fraction of samples to keep for training.

        Returns:
            Sorted positions, into ``self.indices`` as it was *before* this
            call, of the samples that were not selected for training.
        """
        all_indices = self.indices
        num_train = int(len(all_indices) * train_split)
        # Sample positions (not the tuple entries) so the held-out positions
        # can be computed by a set lookup instead of comparing ints to tuples.
        train_positions = random.sample(range(len(all_indices)), num_train)
        train_set = set(train_positions)
        val_indices = [i for i in range(len(all_indices)) if i not in train_set]
        self.indices = [all_indices[i] for i in train_positions]
        return val_indices

    def __len__(self):
        """Number of samples currently indexed."""
        return len(self.indices)
class StitchedSequenceQLearningDataset(StitchedSequenceDataset):
    """
    Extends StitchedSequenceDataset to include rewards and dones for Q learning.

    Do not load the last step of **truncated** episodes since we do not have the
    correct next state for the final step of each episode. Truncation can be
    determined by terminal=False but end of episode.

    Each sample is a ``Transition`` (or ``TransitionWithReturn`` when
    get_mc_return is set). Its conditions additionally contain "next_state":
    the cond_steps-long state history ending ``horizon_steps`` steps after the
    sample's start, built the same way as "state".
    """

    def __init__(
        self,
        dataset_path,
        max_n_episodes=10000,
        discount_factor=1.0,
        device="cuda:0",
        get_mc_return=False,
        **kwargs,
    ):
        """
        Args:
            dataset_path: path to a ``.npz`` or ``.pkl`` file; in addition to
                the base-class keys it must provide "rewards" and "terminals".
            max_n_episodes: only the first this-many episodes are used.
            discount_factor: discount used for the reward-to-go.
            device: torch device the loaded tensors are moved to.
            get_mc_return: if True, precompute the discounted reward-to-go per
                step and return it with each sample.
            **kwargs: forwarded to ``StitchedSequenceDataset.__init__``.

        Raises:
            ValueError: if the file extension is neither .npz nor .pkl.
        """
        if dataset_path.endswith(".npz"):
            dataset = np.load(dataset_path, allow_pickle=False)
        elif dataset_path.endswith(".pkl"):
            # NOTE(review): pickle can execute arbitrary code; only load
            # dataset files from trusted sources.
            with open(dataset_path, "rb") as f:
                dataset = pickle.load(f)
        else:
            raise ValueError(f"Unsupported file format: {dataset_path}")

        try:
            traj_lengths = dataset["traj_lengths"][:max_n_episodes]
            total_num_steps = np.sum(traj_lengths)

            # discount factor
            self.discount_factor = discount_factor

            # rewards and dones(terminals); these must exist before
            # super().__init__, which calls self.make_indices (reads self.dones)
            self.rewards = (
                torch.from_numpy(dataset["rewards"][:total_num_steps]).float().to(device)
            )
            log.info(f"Rewards shape/type: {self.rewards.shape, self.rewards.dtype}")
            self.dones = (
                torch.from_numpy(dataset["terminals"][:total_num_steps]).to(device).float()
            )
            log.info(f"Dones shape/type: {self.dones.shape, self.dones.dtype}")
        finally:
            # np.load on an .npz returns an archive that keeps its file open
            # until closed; a pickled dict has no close() and is skipped.
            close = getattr(dataset, "close", None)
            if close is not None:
                close()

        super().__init__(
            dataset_path=dataset_path,
            max_n_episodes=max_n_episodes,
            device=device,
            **kwargs,
        )
        log.info(f"Total number of transitions using: {len(self)}")

        # compute discounted reward-to-go for each trajectory
        self.get_mc_return = get_mc_return
        if get_mc_return:
            self.reward_to_go = torch.zeros_like(self.rewards)
            traj_start = 0
            for traj_end in tqdm(
                np.cumsum(traj_lengths), desc="Computing reward-to-go"
            ):
                traj_rewards = self.rewards[traj_start:traj_end]
                returns = torch.zeros_like(traj_rewards)
                prev_return = 0
                # backward recursion: G_t = r_t + gamma * G_{t+1}
                for t in reversed(range(len(traj_rewards))):
                    returns[t] = traj_rewards[t] + self.discount_factor * prev_return
                    prev_return = returns[t]
                self.reward_to_go[traj_start:traj_end] = returns
                traj_start = traj_end
            log.info("Computed reward-to-go for each trajectory.")

    def make_indices(self, traj_lengths, horizon_steps):
        """
        Same as the base class, but skip the last start step of truncated
        episodes (last step not marked done), whose next state is unavailable.
        """
        num_skip = 0
        indices = []
        cur_traj_index = 0
        for traj_length in traj_lengths:
            max_start = cur_traj_index + traj_length - horizon_steps
            if not self.dones[cur_traj_index + traj_length - 1]:  # truncation
                max_start -= 1
                num_skip += 1
            indices += [
                (i, i - cur_traj_index) for i in range(cur_traj_index, max_start + 1)
            ]
            cur_traj_index += traj_length
        log.info(f"Number of transitions skipped due to truncation: {num_skip}")
        return indices

    def __getitem__(self, idx):
        """
        Return the transition for sample ``idx``.

        States/images before the episode start are filled by repeating the
        episode's first frame; "next_state" uses the same rule for the step
        ``horizon_steps`` after the sample start.
        """
        start, num_before_start = self.indices[idx]
        traj_start = start - num_before_start
        end = start + self.horizon_steps
        actions = self.actions[start:end]
        rewards = self.rewards[start : (start + 1)]
        dones = self.dones[start : (start + 1)]

        # stack obs history, more recent at the end
        states = torch.stack(
            [
                self.states[max(start - t, traj_start)]
                for t in reversed(range(self.cond_steps))
            ]
        )

        # Account for action horizon. The bounds check is on the absolute step:
        # `idx` says nothing about where `start` lies once indices are skipped
        # or shuffled (e.g. by set_train_val_split).
        next_start = start + self.horizon_steps
        if next_start < len(self.states):
            # May reach the first state(s) of the next episode for the final
            # step of a terminal episode; with horizon_steps == 1 that sample
            # has done=True, so it is not bootstrapped. Truncated episodes were
            # already filtered out in make_indices.
            # NOTE(review): for horizon_steps > 1, `dones` is taken at `start`,
            # not at the chunk's last step — confirm this is intended.
            next_states = torch.stack(
                [
                    self.states[max(next_start - t, traj_start)]
                    for t in reversed(range(self.cond_steps))
                ]
            )
        else:
            # Past the end of the stored data: placeholder of the right shape.
            next_states = torch.zeros_like(states)

        conditions = {"state": states, "next_state": next_states}
        if self.use_img:
            images = torch.stack(
                [
                    self.images[max(start - t, traj_start)]
                    for t in reversed(range(self.img_cond_steps))
                ]
            )
            conditions["rgb"] = images

        if self.get_mc_return:
            reward_to_gos = self.reward_to_go[start : (start + 1)]
            batch = TransitionWithReturn(
                actions,
                conditions,
                rewards,
                dones,
                reward_to_gos,
            )
        else:
            batch = Transition(
                actions,
                conditions,
                rewards,
                dones,
            )
        return batch