Skip to content

Commit

Permalink
vectorize
Browse files Browse the repository at this point in the history
  • Loading branch information
hyeok9855 committed Dec 6, 2024
1 parent 5ce1fdc commit 6f13cff
Showing 1 changed file with 88 additions and 20 deletions.
108 changes: 88 additions & 20 deletions src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,26 +451,94 @@ def reverse_backward_trajectories(trajectories: Trajectories) -> Trajectories:
)
)

new_when_is_done = trajectories.when_is_done + 1
new_states = trajectories.env.sf.repeat(
new_when_is_done.max() + 1, len(trajectories), 1
# Compute sequence lengths and maximum length
seq_lengths = trajectories.when_is_done # shape (n_trajectories,)
max_len = seq_lengths.max().item()

# Get actions and states
actions = (
trajectories.actions.tensor
) # shape (max_len, n_trajectories, *action_dim)
states = (
trajectories.states.tensor
) # shape (max_len + 1, n_trajectories, *state_dim)

# Initialize new actions and states
new_actions = torch.full(
(max_len + 1, len(trajectories), *trajectories.actions.action_shape), -1
).to(
actions
) # shape (max_len + 1, n_trajectories, *action_dim)
new_states = trajectories.env.sf.repeat(max_len + 2, len(trajectories), 1).to(
states
) # shape (max_len + 2, n_trajectories, *state_dim)

# Create helper indices and masks
idx = (
torch.arange(max_len)
.unsqueeze(1)
.expand(-1, len(trajectories))
.to(seq_lengths)
)

# FIXME: Can we vectorize this?
# FIXME: Also, loop over batch or sequence?
for i in range(len(trajectories)):
new_actions[trajectories.when_is_done[i], i] = (
trajectories.env.n_actions - 1
)
new_actions[
: trajectories.when_is_done[i], i
] = trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(0)

new_states[
: trajectories.when_is_done[i] + 1, i
] = trajectories.states.tensor[: trajectories.when_is_done[i] + 1, i].flip(
0
)
rev_idx = seq_lengths - 1 - idx # shape (max_len, n_trajectories)
mask = rev_idx >= 0 # shape (max_len, n_trajectories)
rev_idx[:, 1:] += seq_lengths.cumsum(0)[:-1]

# Transpose for easier indexing
actions = actions.transpose(
0, 1
) # shape (n_trajectories, max_len, *action_dim)
new_actions = new_actions.transpose(
0, 1
) # shape (n_trajectories, max_len + 1, *action_dim)
states = states.transpose(
0, 1
) # shape (n_trajectories, max_len + 1, *state_dim)
new_states = new_states.transpose(
0, 1
) # shape (n_trajectories, max_len + 2, *state_dim)
rev_idx = rev_idx.transpose(0, 1)
mask = mask.transpose(0, 1)

# Assign reversed actions to new_actions
new_actions[:, :-1][mask] = actions[mask][rev_idx[mask]]
new_actions[torch.arange(len(trajectories)), seq_lengths] = (
trajectories.env.n_actions - 1
) # FIXME: This can be problematic if action_dim != 1 (e.g. continuous actions)

# Assign reversed states to new_states
assert torch.all(states[:, -1] == trajectories.env.s0), "Last state must be s0"
new_states[:, 0] = trajectories.env.s0
new_states[:, 1:-1][mask] = states[:, :-1][mask][rev_idx[mask]]

# Transpose back
new_actions = new_actions.transpose(
0, 1
) # shape (max_len + 1, n_trajectories, *action_dim)
new_states = new_states.transpose(
0, 1
) # shape (max_len + 2, n_trajectories, *state_dim)

# TODO: Add below into the test suite to ensure correctness
# new_actions2 = torch.full((max_len + 1, len(trajectories), *trajectories.actions.action_shape), -1).to(actions)
# new_states2 = trajectories.env.sf.repeat(max_len + 2, len(trajectories), 1).to(states) # shape (max_len + 2, n_trajectories, *state_dim)

# for i in range(len(trajectories)):
# new_actions2[trajectories.when_is_done[i], i] = (
# trajectories.env.n_actions - 1
# )
# new_actions2[
# : trajectories.when_is_done[i], i
# ] = trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(0)

# new_states2[
# : trajectories.when_is_done[i] + 1, i
# ] = trajectories.states.tensor[: trajectories.when_is_done[i] + 1, i].flip(
# 0
# )

# assert torch.all(new_actions == new_actions2)
# assert torch.all(new_states == new_states2)

trajectories_states = trajectories.env.states_from_tensor(new_states)
trajectories_actions = trajectories.env.actions_from_tensor(new_actions)
Expand All @@ -480,7 +548,7 @@ def reverse_backward_trajectories(trajectories: Trajectories) -> Trajectories:
states=trajectories_states,
conditioning=trajectories.conditioning,
actions=trajectories_actions,
when_is_done=new_when_is_done,
when_is_done=trajectories.when_is_done + 1,
is_backward=False,
log_rewards=trajectories.log_rewards,
log_probs=None, # We can't simply pass the trajectories.log_probs
Expand Down

0 comments on commit 6f13cff

Please sign in to comment.