From 0cc32f71582c8e566718e3319e20fc34c9ec94c1 Mon Sep 17 00:00:00 2001 From: hyeok9855 Date: Wed, 27 Nov 2024 05:02:04 +0900 Subject: [PATCH] add and fix the --- src/gfn/containers/trajectories.py | 65 +++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 6363c925..bdbd7d07 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -5,9 +5,8 @@ if TYPE_CHECKING: from gfn.actions import Actions from gfn.env import Env - from gfn.states import States, DiscreteStates + from gfn.states import States -import numpy as np import torch from gfn.containers.base import Container @@ -428,6 +427,68 @@ def to_non_initial_intermediary_and_terminating_states( conditioning, ) + @staticmethod + def reverse_backward_trajectories(trajectories: Trajectories) -> Trajectories: + """Reverses a backward trajectory""" + # FIXME: This method is not compatible with continuous GFN. + + assert trajectories.is_backward, "Trajectories must be backward." + new_actions = torch.full( + ( + trajectories.max_length + 1, + len(trajectories), + *trajectories.actions.action_shape, + ), + -1, + ) + + # env.sf should never be None unless something went wrong during class + # instantiation. + if trajectories.env.sf is None: + raise AttributeError( + "Something went wrong during the instantiation of environment {}".format( + trajectories.env + ) + ) + + new_when_is_done = trajectories.when_is_done + 1 + new_states = trajectories.env.sf.repeat( + new_when_is_done.max() + 1, len(trajectories), 1 + ) + + # FIXME: Can we vectorize this? + # FIXME: Also, loop over batch or sequence? + for i in range(len(trajectories)): + new_actions[trajectories.when_is_done[i], i] = ( + trajectories.env.n_actions - 1 + ) + new_actions[ + : trajectories.when_is_done[i], i + ] = trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(0) + + new_states[ + : trajectories.when_is_done[i] + 1, i + ] = trajectories.states.tensor[: trajectories.when_is_done[i] + 1, i].flip( + 0 + ) + + trajectories_states = trajectories.env.States(new_states) + trajectories_actions = trajectories.env.Actions(new_actions) + + return Trajectories( + env=trajectories.env, + states=trajectories_states, + conditioning=trajectories.conditioning, + actions=trajectories_actions, + when_is_done=new_when_is_done, + is_backward=False, + log_rewards=trajectories.log_rewards, + log_probs=None, # We can't simply pass the trajectories.log_probs + # Since `log_probs` is assumed to be the forward log probabilities. + # FIXME: To resolve this, we can save log_pfs and log_pbs in the trajectories object. + estimator_outputs=None, # Same as `log_probs`. + ) + def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor: """Pads tensor a to match the dimention of b."""