Commit 4e11c27

Merge branch 'hyeok9855/reverse_backward_trajectories' into hyeok9855/local-search
hyeok9855 committed Nov 26, 2024
2 parents b075d9e + 0cc32f7 commit 4e11c27
Showing 12 changed files with 4,026 additions and 3,677 deletions.
69 changes: 65 additions & 4 deletions src/gfn/containers/trajectories.py
@@ -5,9 +5,8 @@
if TYPE_CHECKING:
from gfn.actions import Actions
from gfn.env import Env
from gfn.states import States, DiscreteStates
from gfn.states import States

import numpy as np
import torch

from gfn.containers.base import Container
@@ -122,15 +121,15 @@ def __repr__(self) -> str:
for traj in states[:10]:
one_traj_repr = []
for step in traj:
one_traj_repr.append(str(step.numpy()))
one_traj_repr.append(str(step.cpu().numpy()))
if step.equal(self.env.s0 if self.is_backward else self.env.sf):
break
trajectories_representation += "-> ".join(one_traj_repr) + "\n"
return (
f"Trajectories(n_trajectories={self.n_trajectories}, max_length={self.max_length}, First 10 trajectories:"
+ f"states=\n{trajectories_representation}"
# + f"actions=\n{self.actions.tensor.squeeze().transpose(0, 1)[:10].numpy()}, "
+ f"when_is_done={self.when_is_done[:10].numpy()})"
+ f"when_is_done={self.when_is_done[:10].cpu().numpy()})"
)

@property
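
Note: the `__repr__` changes above insert `.cpu()` before `.numpy()`. PyTorch cannot convert a CUDA tensor directly to a NumPy array, so the previous repr failed for trajectories stored on the GPU. A minimal sketch of the pattern:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
t = torch.arange(4, device=device)
arr = t.cpu().numpy()  # t.numpy() raises a TypeError when t lives on a CUDA device
```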
@@ -428,6 +427,68 @@ def to_non_initial_intermediary_and_terminating_states(
conditioning,
)

@staticmethod
def reverse_backward_trajectories(trajectories: Trajectories) -> Trajectories:
"""Reverses a backward trajectory"""
# FIXME: This method is not compatible with continuous GFN.

assert trajectories.is_backward, "Trajectories must be backward."
new_actions = torch.full(
(
trajectories.max_length + 1,
len(trajectories),
*trajectories.actions.action_shape,
),
-1,
)

# env.sf should never be None unless something went wrong during class
# instantiation.
if trajectories.env.sf is None:
raise AttributeError(
"Something went wrong during the instantiation of environment {}".format(
trajectories.env
)
)

new_when_is_done = trajectories.when_is_done + 1
new_states = trajectories.env.sf.repeat(
new_when_is_done.max() + 1, len(trajectories), 1
)

# FIXME: Can we vectorize this?
# FIXME: Also, loop over batch or sequence?
for i in range(len(trajectories)):
new_actions[trajectories.when_is_done[i], i] = (
trajectories.env.n_actions - 1
)
new_actions[
: trajectories.when_is_done[i], i
] = trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(0)

new_states[
: trajectories.when_is_done[i] + 1, i
] = trajectories.states.tensor[: trajectories.when_is_done[i] + 1, i].flip(
0
)

trajectories_states = trajectories.env.States(new_states)
trajectories_actions = trajectories.env.Actions(new_actions)

return Trajectories(
env=trajectories.env,
states=trajectories_states,
conditioning=trajectories.conditioning,
actions=trajectories_actions,
when_is_done=new_when_is_done,
is_backward=False,
log_rewards=trajectories.log_rewards,
log_probs=None, # We can't simply pass the trajectories.log_probs
# Since `log_probs` is assumed to be the forward log probabilities.
# FIXME: To resolve this, we can save log_pfs and log_pbs in the trajectories object.
estimator_outputs=None, # Same as `log_probs`.
)


def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor:
"""Pads tensor a to match the dimention of b."""
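Note: the new `reverse_backward_trajectories` static method turns each backward trajectory (terminating state → `s0`) into a forward one: per-trajectory actions are reversed, the exit action (`n_actions - 1`) is written at `when_is_done`, remaining action slots are padded with `-1`, and states beyond the trajectory are filled with `sf`. A hedged usage sketch — the sampler and the `terminating_states` batch are illustrative and not part of this diff:

```python
# Hypothetical usage: sample trajectories backward from known terminating
# states, then reverse them so they can be scored as forward trajectories.
backward_trajs = backward_sampler.sample_trajectories(env, states=terminating_states)
forward_trajs = Trajectories.reverse_backward_trajectories(backward_trajs)

assert not forward_trajs.is_backward
assert (forward_trajs.when_is_done == backward_trajs.when_is_done + 1).all()
```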
106 changes: 4 additions & 102 deletions src/gfn/gflownet/base.py
@@ -11,11 +11,7 @@
from gfn.modules import GFNModule
from gfn.samplers import Sampler
from gfn.states import States
from gfn.utils.common import has_log_probs
from gfn.utils.handlers import (
has_conditioning_exception_handler,
no_conditioning_exception_handler,
)
from gfn.utils.prob_calculations import get_trajectory_pfs_and_pbs

TrainingSampleType = TypeVar(
"TrainingSampleType", bound=Union[Container, tuple[States, ...]]
@@ -145,6 +141,7 @@ def get_pfs_and_pbs(
trajectories: Trajectories to evaluate.
fill_value: Value to use for invalid states (i.e. $s_f$ that is added to
shorter trajectories).
recalculate_all_logprobs: Whether to re-evaluate all logprobs.
Returns: A tuple of float tensors of shape (max_length, n_trajectories) containing
the log_pf and log_pb for each action in each trajectory. The first one can be None.
@@ -153,103 +150,9 @@
ValueError: if the trajectories are backward.
AssertionError: when actions and states dimensions mismatch.
"""
# fill value is the value used for invalid states (sink state usually)
if trajectories.is_backward:
raise ValueError("Backward trajectories are not supported")

valid_states = trajectories.states[~trajectories.states.is_sink_state]
valid_actions = trajectories.actions[~trajectories.actions.is_dummy]

# uncomment next line for debugging
# assert trajectories.states.is_sink_state[:-1].equal(trajectories.actions.is_dummy)

if valid_states.batch_shape != tuple(valid_actions.batch_shape):
raise AssertionError("Something wrong happening with log_pf evaluations")

if has_log_probs(trajectories) and not recalculate_all_logprobs:
log_pf_trajectories = trajectories.log_probs
else:
if (
trajectories.estimator_outputs is not None
and not recalculate_all_logprobs
):
estimator_outputs = trajectories.estimator_outputs[
~trajectories.actions.is_dummy
]
else:
if trajectories.conditioning is not None:
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~trajectories.states.is_sink_state]

# Here, we pass all valid states, i.e., non-sink states.
with has_conditioning_exception_handler("pf", self.pf):
estimator_outputs = self.pf(valid_states, masked_cond)
else:
# Here, we pass all valid states, i.e., non-sink states.
with no_conditioning_exception_handler("pf", self.pf):
estimator_outputs = self.pf(valid_states)

# Calculates the log PF of the actions sampled off policy.
valid_log_pf_actions = self.pf.to_probability_distribution(
valid_states, estimator_outputs
).log_prob(
valid_actions.tensor
) # Using the actions sampled off-policy.
log_pf_trajectories = torch.full_like(
trajectories.actions.tensor[..., 0],
fill_value=fill_value,
dtype=torch.float,
)
log_pf_trajectories[~trajectories.actions.is_dummy] = valid_log_pf_actions

non_initial_valid_states = valid_states[~valid_states.is_initial_state]
non_exit_valid_actions = valid_actions[~valid_actions.is_exit]

# Using all non-initial states, calculate the backward policy, and the logprobs
# of those actions.
if trajectories.conditioning is not None:
# We need to index the conditioning vector to broadcast over the states.
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~trajectories.states.is_sink_state][~valid_states.is_initial_state]

# Pass all valid states, i.e., non-sink states, except the initial state.
with has_conditioning_exception_handler("pb", self.pb):
estimator_outputs = self.pb(non_initial_valid_states, masked_cond)
else:
# Pass all valid states, i.e., non-sink states, except the initial state.
with no_conditioning_exception_handler("pb", self.pb):
estimator_outputs = self.pb(non_initial_valid_states)

valid_log_pb_actions = self.pb.to_probability_distribution(
non_initial_valid_states, estimator_outputs
).log_prob(non_exit_valid_actions.tensor)

log_pb_trajectories = torch.full_like(
trajectories.actions.tensor[..., 0],
fill_value=fill_value,
dtype=torch.float,
return get_trajectory_pfs_and_pbs(
self.pf, self.pb, trajectories, fill_value, recalculate_all_logprobs
)
log_pb_trajectories_slice = torch.full_like(
valid_actions.tensor[..., 0], fill_value=fill_value, dtype=torch.float
)
log_pb_trajectories_slice[~valid_actions.is_exit] = valid_log_pb_actions
log_pb_trajectories[~trajectories.actions.is_dummy] = log_pb_trajectories_slice

assert log_pf_trajectories.shape == (
trajectories.max_length,
trajectories.n_trajectories,
)
assert log_pb_trajectories.shape == (
trajectories.max_length,
trajectories.n_trajectories,
)
return log_pf_trajectories, log_pb_trajectories

def get_trajectories_scores(
self,
@@ -265,7 +168,6 @@ def get_trajectories_scores(
Returns: A tuple of float tensors of shape (n_trajectories,)
containing the total log_pf, total log_pb, and the total
log-likelihood of the trajectories.
"""
log_pf_trajectories, log_pb_trajectories = self.get_pfs_and_pbs(
trajectories, recalculate_all_logprobs=recalculate_all_logprobs
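Note: the ~100-line body of `GFlowNet.get_pfs_and_pbs` is replaced by a call to the new shared helper `gfn.utils.prob_calculations.get_trajectory_pfs_and_pbs`. The sketch below condenses what the removed body computed, as a simplified reading of this diff (no conditioning, no reuse of cached `log_probs` or `estimator_outputs`):

```python
import torch

def trajectory_pfs_and_pbs_sketch(pf, pb, trajectories, fill_value=0.0):
    """Simplified view of the per-step log P_F / log P_B computation."""
    valid_states = trajectories.states[~trajectories.states.is_sink_state]
    valid_actions = trajectories.actions[~trajectories.actions.is_dummy]

    # Forward policy: one log-prob per non-dummy (state, action) pair,
    # scattered back into a (max_length, n_trajectories) tensor.
    log_pf = torch.full_like(
        trajectories.actions.tensor[..., 0], fill_value, dtype=torch.float
    )
    pf_dist = pf.to_probability_distribution(valid_states, pf(valid_states))
    log_pf[~trajectories.actions.is_dummy] = pf_dist.log_prob(valid_actions.tensor)

    # Backward policy: only defined for non-initial states and non-exit actions.
    non_initial_states = valid_states[~valid_states.is_initial_state]
    non_exit_actions = valid_actions[~valid_actions.is_exit]
    pb_dist = pb.to_probability_distribution(non_initial_states, pb(non_initial_states))

    log_pb_slice = torch.full_like(
        valid_actions.tensor[..., 0], fill_value, dtype=torch.float
    )
    log_pb_slice[~valid_actions.is_exit] = pb_dist.log_prob(non_exit_actions.tensor)

    log_pb = torch.full_like(log_pf, fill_value)
    log_pb[~trajectories.actions.is_dummy] = log_pb_slice
    return log_pf, log_pb
```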
67 changes: 21 additions & 46 deletions src/gfn/gflownet/detailed_balance.py
@@ -12,6 +12,7 @@
has_conditioning_exception_handler,
no_conditioning_exception_handler,
)
from gfn.utils.prob_calculations import get_transition_pfs_and_pbs


def check_compatibility(states, actions, transitions):
@@ -78,6 +79,13 @@ def logF_parameters(self):
)
)

def get_pfs_and_pbs(
self, transitions: Transitions, recalculate_all_logprobs: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
return get_transition_pfs_and_pbs(
self.pf, self.pb, transitions, recalculate_all_logprobs
)

def get_scores(
self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -101,70 +109,39 @@ def get_scores(
"""
if transitions.is_backward:
raise ValueError("Backward transitions are not supported")

states = transitions.states
actions = transitions.actions

# uncomment next line for debugging
# assert transitions.states.is_sink_state.equal(transitions.actions.is_dummy)
check_compatibility(states, actions, transitions)

if has_log_probs(transitions) and not recalculate_all_logprobs:
valid_log_pf_actions = transitions.log_probs
else:
# Evaluate the log PF of the actions, with optional conditioning.
# TODO: Inefficient duplication in case of tempered policy
# The Transitions container should then have some
# estimator_outputs attribute as well, to avoid duplication here ?
# See (#156).
if transitions.conditioning is not None:
with has_conditioning_exception_handler("pf", self.pf):
module_output = self.pf(states, transitions.conditioning)
else:
with no_conditioning_exception_handler("pf", self.pf):
module_output = self.pf(states)

valid_log_pf_actions = self.pf.to_probability_distribution(
states, module_output
).log_prob(actions.tensor)
log_pf_actions, log_pb_actions = self.get_pfs_and_pbs(
transitions, recalculate_all_logprobs
)

# LogF is potentially a conditional computation.
if transitions.conditioning is not None:
with has_conditioning_exception_handler("logF", self.logF):
valid_log_F_s = self.logF(states, transitions.conditioning).squeeze(-1)
log_F_s = self.logF(states, transitions.conditioning).squeeze(-1)
else:
with no_conditioning_exception_handler("logF", self.logF):
valid_log_F_s = self.logF(states).squeeze(-1)
log_F_s = self.logF(states).squeeze(-1)

if self.forward_looking:
log_rewards = env.log_reward(states) # TODO: RM unsqueeze(-1) ?
if math.isfinite(self.log_reward_clip_min):
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
valid_log_F_s = valid_log_F_s + log_rewards
log_F_s = log_F_s + log_rewards

preds = valid_log_pf_actions + valid_log_F_s
targets = torch.zeros_like(preds)
preds = log_pf_actions + log_F_s

# uncomment next line for debugging
# assert transitions.next_states.is_sink_state.equal(transitions.is_done)

# automatically removes invalid transitions (i.e. s_f -> s_f)
valid_next_states = transitions.next_states[~transitions.is_done]
non_exit_actions = actions[~actions.is_exit]

# Evaluate the log PB of the actions, with optional conditioning.
if transitions.conditioning is not None:
with has_conditioning_exception_handler("pb", self.pb):
module_output = self.pb(
valid_next_states, transitions.conditioning[~transitions.is_done]
)
else:
with no_conditioning_exception_handler("pb", self.pb):
module_output = self.pb(valid_next_states)

valid_log_pb_actions = self.pb.to_probability_distribution(
valid_next_states, module_output
).log_prob(non_exit_actions.tensor)

valid_transitions_is_done = transitions.is_done[
~transitions.states.is_sink_state
]
@@ -179,23 +156,21 @@
with no_conditioning_exception_handler("logF", self.logF):
valid_log_F_s_next = self.logF(valid_next_states).squeeze(-1)

targets[~valid_transitions_is_done] = valid_log_pb_actions
log_pb_actions = targets.clone()
targets[~valid_transitions_is_done] += valid_log_F_s_next
log_F_s_next = torch.zeros_like(log_pb_actions)
log_F_s_next[~valid_transitions_is_done] = valid_log_F_s_next
assert transitions.log_rewards is not None
valid_transitions_log_rewards = transitions.log_rewards[
~transitions.states.is_sink_state
]
targets[valid_transitions_is_done] = valid_transitions_log_rewards[
log_F_s_next[valid_transitions_is_done] = valid_transitions_log_rewards[
valid_transitions_is_done
]
targets = log_pb_actions + log_F_s_next

scores = preds - targets

assert valid_log_pf_actions.shape == (transitions.n_transitions,)
assert log_pb_actions.shape == (transitions.n_transitions,)
assert scores.shape == (transitions.n_transitions,)
return valid_log_pf_actions, log_pb_actions, scores
return log_pf_actions, log_pb_actions, scores

def loss(self, env: Env, transitions: Transitions) -> torch.Tensor:
"""Detailed balance loss.
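Note: after this refactor, `get_scores` obtains `log_pf_actions` and `log_pb_actions` from the shared `get_transition_pfs_and_pbs` helper and only assembles the detailed-balance score: `preds = log P_F(a|s) + log F(s)` against `targets = log P_B(a|s') + log F(s')`, with `log F(s')` replaced by the log-reward on terminating transitions. A toy numeric sketch of that assembly, using made-up values for three transitions (the last one terminating):

```python
import torch

log_pf_actions = torch.tensor([-0.7, -1.2, -0.4])  # log P_F(a | s)
log_pb_actions = torch.tensor([-0.5, -0.9, 0.0])   # log P_B(a | s'); 0.0 on the exit action
log_F_s = torch.tensor([1.0, 0.3, -0.2])           # log F(s)
log_F_s_next = torch.tensor([0.8, 0.1, 0.0])       # log F(s') on intermediate transitions
log_reward = torch.tensor([0.0, 0.0, -1.5])        # log R(s') on terminating transitions
is_done = torch.tensor([False, False, True])

preds = log_pf_actions + log_F_s
targets = log_pb_actions + torch.where(is_done, log_reward, log_F_s_next)
scores = preds - targets            # the DB loss drives these residuals toward zero
loss = (scores ** 2).mean()
```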
2 changes: 1 addition & 1 deletion src/gfn/gym/discrete_ebm.py
@@ -221,7 +221,7 @@ def get_states_indices(self, states: DiscreteStates) -> torch.Tensor:
return states_indices

def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor:
"""Returns the indices of the terminating states.
"""Get the indices of the terminating states in the canonical ordering from the submitted states.
Args:
states: DiscreteStates object representing the states.
5 changes: 3 additions & 2 deletions src/gfn/gym/hypergrid.py
@@ -86,6 +86,7 @@ def update_masks(self, states: DiscreteStates) -> None:
"""Update the masks based on the current states."""
# Not allowed to take any action beyond the environment height, but
# allow early termination.
# TODO: do we need to handle the conditional case here?
states.set_nonexit_action_masks(
states.tensor == self.height - 1,
allow_exit=True,
@@ -174,9 +175,9 @@ def get_states_indices(self, states: DiscreteStates) -> torch.Tensor:
return indices

def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor:
"""Get the indices of the terminating states in the canonical ordering.
"""Get the indices of the terminating states in the canonical ordering from the submitted states.
Returns the indices of the terminating states in the canonical ordering as a tensor of shape `batch_shape`.
Canonical ordering is returned as a tensor of shape `batch_shape`.
"""
return self.get_states_indices(states)

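Note: the docstring fixes in `hypergrid.py` (and the matching one in `discrete_ebm.py`) clarify that `get_terminating_states_indices` maps the states passed in to their positions in the environment's canonical enumeration; in `HyperGrid`, where any state may terminate a trajectory, it simply reuses `get_states_indices` (visible above). As an illustration only — the library's exact formula is not shown in this diff — a row-major canonical index over a `height ** ndim` grid could look like:

```python
import torch

def canonical_index(coords: torch.Tensor, height: int) -> torch.Tensor:
    """Row-major index of grid coordinates, e.g. (x0, x1) -> x0 * height + x1."""
    ndim = coords.shape[-1]
    weights = height ** torch.arange(ndim - 1, -1, -1)
    return (coords * weights).sum(-1)

coords = torch.tensor([[0, 0], [1, 2], [3, 3]])   # a 2-D grid with height = 4
print(canonical_index(coords, height=4))          # tensor([ 0,  6, 15])
```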
(Diffs for the remaining 7 changed files are not shown.)
