diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 6c8e9115..825d7f1c 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -451,26 +451,94 @@ def reverse_backward_trajectories(trajectories: Trajectories) -> Trajectories: ) ) - new_when_is_done = trajectories.when_is_done + 1 - new_states = trajectories.env.sf.repeat( - new_when_is_done.max() + 1, len(trajectories), 1 + # Compute sequence lengths and maximum length + seq_lengths = trajectories.when_is_done # shape (n_trajectories,) + max_len = seq_lengths.max().item() + + # Get actions and states + actions = ( + trajectories.actions.tensor + ) # shape (max_len, n_trajectories *action_dim) + states = ( + trajectories.states.tensor + ) # shape (max_len + 1, n_trajectories, *state_dim) + + # Initialize new actions and states + new_actions = torch.full( + (max_len + 1, len(trajectories), *trajectories.actions.action_shape), -1 + ).to( + actions + ) # shape (max_len + 1, n_trajectories, *action_dim) + new_states = trajectories.env.sf.repeat(max_len + 2, len(trajectories), 1).to( + states + ) # shape (max_len + 2, n_trajectories, *state_dim) + + # Create helper indices and masks + idx = ( + torch.arange(max_len) + .unsqueeze(1) + .expand(-1, len(trajectories)) + .to(seq_lengths) ) - - # FIXME: Can we vectorize this? - # FIXME: Also, loop over batch or sequence? - for i in range(len(trajectories)): - new_actions[trajectories.when_is_done[i], i] = ( - trajectories.env.n_actions - 1 - ) - new_actions[ - : trajectories.when_is_done[i], i - ] = trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(0) - - new_states[ - : trajectories.when_is_done[i] + 1, i - ] = trajectories.states.tensor[: trajectories.when_is_done[i] + 1, i].flip( - 0 - ) + rev_idx = seq_lengths - 1 - idx # shape (max_len, n_trajectories) + mask = rev_idx >= 0 # shape (max_len, n_trajectories) + rev_idx[:, 1:] += seq_lengths.cumsum(0)[:-1] + + # Transpose for easier indexing + actions = actions.transpose( + 0, 1 + ) # shape (n_trajectories, max_len, *action_dim) + new_actions = new_actions.transpose( + 0, 1 + ) # shape (n_trajectories, max_len + 1, *action_dim) + states = states.transpose( + 0, 1 + ) # shape (n_trajectories, max_len + 1, *state_dim) + new_states = new_states.transpose( + 0, 1 + ) # shape (n_trajectories, max_len + 2, *state_dim) + rev_idx = rev_idx.transpose(0, 1) + mask = mask.transpose(0, 1) + + # Assign reversed actions to new_actions + new_actions[:, :-1][mask] = actions[mask][rev_idx[mask]] + new_actions[torch.arange(len(trajectories)), seq_lengths] = ( + trajectories.env.n_actions - 1 + ) # FIXME: This can be problematic if action_dim != 1 (e.g. continuous actions) + + # Assign reversed states to new_states + assert torch.all(states[:, -1] == trajectories.env.s0), "Last state must be s0" + new_states[:, 0] = trajectories.env.s0 + new_states[:, 1:-1][mask] = states[:, :-1][mask][rev_idx[mask]] + + # Transpose back + new_actions = new_actions.transpose( + 0, 1 + ) # shape (max_len + 1, n_trajectories, *action_dim) + new_states = new_states.transpose( + 0, 1 + ) # shape (max_len + 2, n_trajectories, *state_dim) + + # TODO: Add below into the test suite to ensure correctness + # new_actions2 = torch.full((max_len + 1, len(trajectories), *trajectories.actions.action_shape), -1).to(actions) + # new_states2 = trajectories.env.sf.repeat(max_len + 2, len(trajectories), 1).to(states) # shape (max_len + 2, n_trajectories, *state_dim) + + # for i in range(len(trajectories)): + # new_actions2[trajectories.when_is_done[i], i] = ( + # trajectories.env.n_actions - 1 + # ) + # new_actions2[ + # : trajectories.when_is_done[i], i + # ] = trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(0) + + # new_states2[ + # : trajectories.when_is_done[i] + 1, i + # ] = trajectories.states.tensor[: trajectories.when_is_done[i] + 1, i].flip( + # 0 + # ) + + # assert torch.all(new_actions == new_actions2) + # assert torch.all(new_states == new_states2) trajectories_states = trajectories.env.states_from_tensor(new_states) trajectories_actions = trajectories.env.actions_from_tensor(new_actions) @@ -480,7 +548,7 @@ def reverse_backward_trajectories(trajectories: Trajectories) -> Trajectories: states=trajectories_states, conditioning=trajectories.conditioning, actions=trajectories_actions, - when_is_done=new_when_is_done, + when_is_done=trajectories.when_is_done + 1, is_backward=False, log_rewards=trajectories.log_rewards, log_probs=None, # We can't simply pass the trajectories.log_probs