v0.5 to main (#10)
* v0.5 (#9)

* update idql configs

* update awr configs

* update dipo configs

* update qsm configs

* update dqm configs

* update project version to 0.5.0
allenzren authored Oct 7, 2024
1 parent dd14c58 commit e0842e7
Showing 267 changed files with 6,759 additions and 1,635 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -10,6 +10,7 @@ checkpoints/
out/
err/
*.pkl
*.sh

# Byte-compiled / optimized / DLL files
__pycache__/
161 changes: 158 additions & 3 deletions agent/dataset/sequence.py
@@ -11,10 +11,15 @@
import logging
import pickle
import random
from tqdm import tqdm

log = logging.getLogger(__name__)

Batch = namedtuple("Batch", "actions conditions")
Transition = namedtuple("Transition", "actions conditions rewards dones")
TransitionWithReturn = namedtuple(
"Transition", "actions conditions rewards dones reward_to_gos"
)


class StitchedSequenceDataset(torch.utils.data.Dataset):
@@ -49,6 +54,8 @@ def __init__(
self.img_cond_steps = img_cond_steps
self.device = device
self.use_img = use_img
self.max_n_episodes = max_n_episodes
self.dataset_path = dataset_path

# Load dataset to device specified
if dataset_path.endswith(".npz"):
@@ -87,7 +94,7 @@ def __getitem__(self, idx):
"""
start, num_before_start = self.indices[idx]
end = start + self.horizon_steps
states = self.states[(start - num_before_start) : end]
states = self.states[(start - num_before_start) : (start + 1)]
actions = self.actions[start:end]
states = torch.stack(
[
@@ -116,9 +123,9 @@ def make_indices(self, traj_lengths, horizon_steps):
indices = []
cur_traj_index = 0
for traj_length in traj_lengths:
max_start = cur_traj_index + traj_length - horizon_steps + 1
max_start = cur_traj_index + traj_length - horizon_steps
indices += [
(i, i - cur_traj_index) for i in range(cur_traj_index, max_start)
(i, i - cur_traj_index) for i in range(cur_traj_index, max_start + 1)
]
cur_traj_index += traj_length
return indices
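
The rewrite of `make_indices` here does not change which windows the base dataset yields: the old exclusive bound and the new inclusive `max_start` cover the same start indices, and the inclusive form is what the Q-learning subclass later decrements for truncated episodes. A minimal standalone check, assuming a single trajectory of length 5 and `horizon_steps = 2`:

```python
traj_length, horizon_steps, cur_traj_index = 5, 2, 0

# Old formulation: exclusive upper bound.
old_max_start = cur_traj_index + traj_length - horizon_steps + 1
old_starts = list(range(cur_traj_index, old_max_start))        # [0, 1, 2, 3]

# New formulation: inclusive max_start.
new_max_start = cur_traj_index + traj_length - horizon_steps
new_starts = list(range(cur_traj_index, new_max_start + 1))    # [0, 1, 2, 3]

assert old_starts == new_starts
```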
@@ -135,3 +142,151 @@ def set_train_val_split(self, train_split):

def __len__(self):
return len(self.indices)


class StitchedSequenceQLearningDataset(StitchedSequenceDataset):
"""
Extends StitchedSequenceDataset to include rewards and dones for Q learning
Do not load the last step of **truncated** episodes since we do not have the correct next state for the final step of each episode. Truncation can be determined by terminal=False but end of episode.
"""

def __init__(
self,
dataset_path,
max_n_episodes=10000,
discount_factor=1.0,
device="cuda:0",
get_mc_return=False,
**kwargs,
):
if dataset_path.endswith(".npz"):
dataset = np.load(dataset_path, allow_pickle=False)
elif dataset_path.endswith(".pkl"):
with open(dataset_path, "rb") as f:
dataset = pickle.load(f)
else:
raise ValueError(f"Unsupported file format: {dataset_path}")
traj_lengths = dataset["traj_lengths"][:max_n_episodes]
total_num_steps = np.sum(traj_lengths)

# discount factor
self.discount_factor = discount_factor

# rewards and dones(terminals)
self.rewards = (
torch.from_numpy(dataset["rewards"][:total_num_steps]).float().to(device)
)
log.info(f"Rewards shape/type: {self.rewards.shape, self.rewards.dtype}")
self.dones = (
torch.from_numpy(dataset["terminals"][:total_num_steps]).to(device).float()
)
log.info(f"Dones shape/type: {self.dones.shape, self.dones.dtype}")

super().__init__(
dataset_path=dataset_path,
max_n_episodes=max_n_episodes,
device=device,
**kwargs,
)
log.info(f"Total number of transitions using: {len(self)}")

# compute discounted reward-to-go for each trajectory
self.get_mc_return = get_mc_return
if get_mc_return:
self.reward_to_go = torch.zeros_like(self.rewards)
cumulative_traj_length = np.cumsum(traj_lengths)
prev_traj_length = 0
for i, traj_length in tqdm(
enumerate(cumulative_traj_length), desc="Computing reward-to-go"
):
traj_rewards = self.rewards[prev_traj_length:traj_length]
returns = torch.zeros_like(traj_rewards)
prev_return = 0
for t in range(len(traj_rewards)):
returns[-t - 1] = (
traj_rewards[-t - 1] + self.discount_factor * prev_return
)
prev_return = returns[-t - 1]
self.reward_to_go[prev_traj_length:traj_length] = returns
prev_traj_length = traj_length
log.info(f"Computed reward-to-go for each trajectory.")

def make_indices(self, traj_lengths, horizon_steps):
"""
skip last step of truncated episodes
"""
num_skip = 0
indices = []
cur_traj_index = 0
for traj_length in traj_lengths:
max_start = cur_traj_index + traj_length - horizon_steps
if not self.dones[cur_traj_index + traj_length - 1]: # truncation
max_start -= 1
num_skip += 1
indices += [
(i, i - cur_traj_index) for i in range(cur_traj_index, max_start + 1)
]
cur_traj_index += traj_length
log.info(f"Number of transitions skipped due to truncation: {num_skip}")
return indices

def __getitem__(self, idx):
start, num_before_start = self.indices[idx]
end = start + self.horizon_steps
states = self.states[(start - num_before_start) : (start + 1)]
actions = self.actions[start:end]
rewards = self.rewards[start : (start + 1)]
dones = self.dones[start : (start + 1)]

# Account for action horizon
if idx < len(self.indices) - self.horizon_steps:
next_states = self.states[
(start - num_before_start + self.horizon_steps) : start
+ 1
+ self.horizon_steps
] # even if this uses the first state(s) of the next episode, done=True will prevent bootstrapping. We have already filtered out cases where done=False but end of episode (truncation).
else:
# prevents indexing error, but ignored since done=True
next_states = torch.zeros_like(states)

# stack obs history
states = torch.stack(
[
states[max(num_before_start - t, 0)]
for t in reversed(range(self.cond_steps))
]
) # more recent is at the end
next_states = torch.stack(
[
next_states[max(num_before_start - t, 0)]
for t in reversed(range(self.cond_steps))
]
) # more recent is at the end
conditions = {"state": states, "next_state": next_states}
if self.use_img:
images = self.images[(start - num_before_start) : end]
images = torch.stack(
[
images[max(num_before_start - t, 0)]
for t in reversed(range(self.img_cond_steps))
]
)
conditions["rgb"] = images
if self.get_mc_return:
reward_to_gos = self.reward_to_go[start : (start + 1)]
batch = TransitionWithReturn(
actions,
conditions,
rewards,
dones,
reward_to_gos,
)
else:
batch = Transition(
actions,
conditions,
rewards,
dones,
)
return batch
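
The new `StitchedSequenceQLearningDataset` does two things on top of the base dataset: it drops the last window of truncated episodes (episodes whose final step carries no terminal flag), and it can precompute a discounted reward-to-go R_t = r_t + discount * R_{t+1} per trajectory. A minimal standalone sketch of both pieces of logic, using NumPy arrays and hypothetical toy inputs in place of the dataset's torch tensors:

```python
import numpy as np

def make_indices(traj_lengths, horizon_steps, dones):
    """Window starts per trajectory, dropping the last start of any episode
    that ends without a terminal flag (i.e. a truncated episode)."""
    indices, cur = [], 0
    for traj_length in traj_lengths:
        max_start = cur + traj_length - horizon_steps
        if not dones[cur + traj_length - 1]:  # truncation
            max_start -= 1
        indices += [(i, i - cur) for i in range(cur, max_start + 1)]
        cur += traj_length
    return indices

def reward_to_go(rewards, traj_lengths, discount=0.99):
    """Backward pass per trajectory: R_t = r_t + discount * R_{t+1}."""
    out = np.zeros_like(rewards, dtype=float)
    start = 0
    for traj_length in traj_lengths:
        running = 0.0
        for t in reversed(range(start, start + traj_length)):
            running = rewards[t] + discount * running
            out[t] = running
        start += traj_length
    return out

# Toy data: one terminated episode of length 4, one truncated episode of length 3.
traj_lengths = [4, 3]
dones = np.array([0, 0, 0, 1, 0, 0, 0], dtype=bool)
rewards = np.array([0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0])

print(make_indices(traj_lengths, horizon_steps=2, dones=dones))
# [(0, 0), (1, 1), (2, 2), (4, 0)] -- the truncated episode loses its final start
print(reward_to_go(rewards, traj_lengths, discount=0.5))
# 0.125, 0.25, 0.5, 1.0 for the first episode; 0.25, 0.5, 1.0 for the second
```

The truncated episode loses its final start because that step has no valid next state to bootstrap from, which is exactly the case the class docstring above describes.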
12 changes: 8 additions & 4 deletions agent/eval/eval_diffusion_agent.py
@@ -36,7 +36,7 @@ def run(self):
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1
reward_trajs = np.empty((0, self.n_envs))
reward_trajs = np.zeros((self.n_steps, self.n_envs))

# Collect a set of trajectories from env
for step in range(self.n_steps):
@@ -57,9 +57,13 @@ def run(self):
action_venv = output_venv[:, : self.act_steps]

# Apply multi-step action
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(action_venv)
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
firsts_trajs[step + 1] = done_venv
obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = (
self.venv.step(action_venv)
)
reward_trajs[step] = reward_venv
firsts_trajs[step + 1] = terminated_venv | truncated_venv

# update for next step
prev_obs_venv = obs_venv

# Summarize episode reward --- this needs to be handled differently depending on whether the environment is reset after each iteration. Only count episodes that finish within the iteration.
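
The eval agents now unpack the five-tuple Gymnasium-style return from `venv.step` (separate `terminated` and `truncated` flags) and preallocate `reward_trajs` instead of growing it with `np.vstack`; the same change appears in the other eval agents below. A minimal sketch of the pattern, assuming a stock Gymnasium vector env (`CartPole-v1` with random actions) rather than the repository's own wrapped environments and policy:

```python
import gymnasium as gym
import numpy as np

# Stand-in vector env and sizes; the real agent uses its config and wrapped envs.
n_steps, n_envs = 8, 4
venv = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")] * n_envs)

firsts_trajs = np.zeros((n_steps + 1, n_envs))
reward_trajs = np.zeros((n_steps, n_envs))  # preallocated; no np.vstack inside the loop
firsts_trajs[0] = 1

prev_obs_venv, info_venv = venv.reset(seed=0)
for step in range(n_steps):
    action_venv = venv.action_space.sample()  # the real agent samples from its policy
    # Gymnasium-style step: five return values, with termination and truncation separated.
    obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = venv.step(action_venv)
    reward_trajs[step] = reward_venv
    # An episode boundary is either a true terminal state or a time-limit truncation.
    firsts_trajs[step + 1] = terminated_venv | truncated_venv
    prev_obs_venv = obs_venv  # the next policy call conditions on the latest observation

print(reward_trajs.sum(axis=0))
```

Treating `terminated | truncated` as the episode boundary keeps `firsts_trajs` consistent with the old `done` semantics while still letting the two cases be distinguished elsewhere.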
12 changes: 8 additions & 4 deletions agent/eval/eval_diffusion_img_agent.py
@@ -40,7 +40,7 @@ def run(self):
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1
reward_trajs = np.empty((0, self.n_envs))
reward_trajs = np.zeros((self.n_steps, self.n_envs))

# Collect a set of trajectories from env
for step in range(self.n_steps):
@@ -60,9 +60,13 @@ def run(self):
action_venv = output_venv[:, : self.act_steps]

# Apply multi-step action
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(action_venv)
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
firsts_trajs[step + 1] = done_venv
obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = (
self.venv.step(action_venv)
)
reward_trajs[step] = reward_venv
firsts_trajs[step + 1] = terminated_venv | truncated_venv

# update for next step
prev_obs_venv = obs_venv

# Summarize episode reward --- this needs to be handled differently depending on whether the environment is reset after each iteration. Only count episodes that finish within the iteration.
12 changes: 8 additions & 4 deletions agent/eval/eval_gaussian_agent.py
@@ -36,7 +36,7 @@ def run(self):
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1
reward_trajs = np.empty((0, self.n_envs))
reward_trajs = np.zeros((self.n_steps, self.n_envs))

# Collect a set of trajectories from env
for step in range(self.n_steps):
@@ -55,9 +55,13 @@ def run(self):
action_venv = output_venv[:, : self.act_steps]

# Apply multi-step action
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(action_venv)
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
firsts_trajs[step + 1] = done_venv
obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = (
self.venv.step(action_venv)
)
reward_trajs[step] = reward_venv
firsts_trajs[step + 1] = terminated_venv | truncated_venv

# update for next step
prev_obs_venv = obs_venv

# Summarize episode reward --- this needs to be handled differently depending on whether the environment is reset after each iteration. Only count episodes that finish within the iteration.
12 changes: 8 additions & 4 deletions agent/eval/eval_gaussian_img_agent.py
@@ -40,7 +40,7 @@ def run(self):
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1
reward_trajs = np.empty((0, self.n_envs))
reward_trajs = np.zeros((self.n_steps, self.n_envs))

# Collect a set of trajectories from env
for step in range(self.n_steps):
@@ -58,9 +58,13 @@ def run(self):
action_venv = output_venv[:, : self.act_steps]

# Apply multi-step action
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(action_venv)
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
firsts_trajs[step + 1] = done_venv
obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = (
self.venv.step(action_venv)
)
reward_trajs[step] = reward_venv
firsts_trajs[step + 1] = terminated_venv | truncated_venv

# update for next step
prev_obs_venv = obs_venv

# Summarize episode reward --- this needs to be handled differently depending on whether the environment is reset after each iteration. Only count episodes that finish within the iteration.