Skip to content

Commit

Permalink
add and fix the
Browse files Browse the repository at this point in the history
  • Loading branch information
hyeok9855 committed Nov 26, 2024
1 parent 8481673 commit 0cc32f7
Showing 1 changed file with 63 additions and 2 deletions.
65 changes: 63 additions & 2 deletions src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
if TYPE_CHECKING:
from gfn.actions import Actions
from gfn.env import Env
from gfn.states import States, DiscreteStates
from gfn.states import States

import numpy as np
import torch

from gfn.containers.base import Container
Expand Down Expand Up @@ -428,6 +427,68 @@ def to_non_initial_intermediary_and_terminating_states(
conditioning,
)

@staticmethod
def reverse_backward_trajectories(trajectories: Trajectories) -> Trajectories:
    """Build forward trajectories by reversing a batch of backward ones.

    For every trajectory ``i``, the first ``when_is_done[i]`` actions and the
    first ``when_is_done[i] + 1`` states are flipped along the time dimension
    (dim 0). The value ``env.n_actions - 1`` (presumably the index of the exit
    action -- confirm against the environment) is written right after the
    flipped actions. Remaining action slots are padded with ``-1`` and
    remaining state slots with ``env.sf``.

    Args:
        trajectories: The trajectories to reverse. ``is_backward`` must be
            truthy.

    Returns:
        A new ``Trajectories`` object with ``is_backward=False``,
        ``when_is_done`` incremented by one, the same ``env``,
        ``conditioning`` and ``log_rewards``, and with ``log_probs`` and
        ``estimator_outputs`` set to ``None``.

    Raises:
        AssertionError: If ``trajectories.is_backward`` is falsy.
        AttributeError: If ``trajectories.env.sf`` is ``None``.
    """
    # FIXME: This method is not compatible with continuous GFN.

    assert trajectories.is_backward, "Trajectories must be backward."

    env = trajectories.env
    # env.sf should never be None unless something went wrong during class
    # instantiation.
    if env.sf is None:
        raise AttributeError(
            "Something went wrong during the instantiation of environment {}".format(
                env
            )
        )

    n_trajectories = len(trajectories)
    old_actions = trajectories.actions.tensor
    old_states = trajectories.states.tensor
    when_is_done = trajectories.when_is_done
    new_when_is_done = when_is_done + 1

    # Allocate the padded action buffer with the dtype and device of the source
    # actions. Without these, `torch.full` always uses the default device, so
    # the reversed actions would not live alongside the states when the
    # trajectories are stored on an accelerator.
    new_actions = torch.full(
        (
            trajectories.max_length + 1,
            n_trajectories,
            *trajectories.actions.action_shape,
        ),
        -1,
        dtype=old_actions.dtype,
        device=old_actions.device,
    )

    # One more state row than the longest reversed trajectory, pre-filled with
    # env.sf so that anything past a trajectory's end is padding.
    # NOTE(review): `.max()` raises on an empty batch -- confirm whether empty
    # inputs need to be supported here.
    new_states = env.sf.repeat(new_when_is_done.max() + 1, n_trajectories, 1)

    # FIXME: Can we vectorize this?
    # FIXME: Also, loop over batch or sequence?
    for i in range(n_trajectories):
        n_steps = when_is_done[i]

        # Slot right after the reversed actions gets `n_actions - 1`.
        new_actions[n_steps, i] = env.n_actions - 1
        new_actions[:n_steps, i] = old_actions[:n_steps, i].flip(0)

        # States have one more valid entry than actions.
        new_states[: n_steps + 1, i] = old_states[: n_steps + 1, i].flip(0)

    trajectories_states = env.States(new_states)
    trajectories_actions = env.Actions(new_actions)

    return Trajectories(
        env=env,
        states=trajectories_states,
        conditioning=trajectories.conditioning,
        actions=trajectories_actions,
        when_is_done=new_when_is_done,
        is_backward=False,
        log_rewards=trajectories.log_rewards,
        log_probs=None,  # We can't simply pass the trajectories.log_probs
        # Since `log_probs` is assumed to be the forward log probabilities.
        # FIXME: To resolve this, we can save log_pfs and log_pbs in the trajectories object.
        estimator_outputs=None,  # Same as `log_probs`.
    )


def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor:
"""Pads tensor a to match the dimention of b."""
Expand Down

0 comments on commit 0cc32f7

Please sign in to comment.