[WIP] commit for sharing the redundancy issue
hyeok9855 committed Nov 8, 2024
1 parent d7d95ca commit b075d9e
Showing 3 changed files with 132 additions and 34 deletions.
163 changes: 130 additions & 33 deletions src/gfn/samplers.py
@@ -290,8 +290,11 @@ def local_search(
save_logprobs: bool = True,
back_steps: torch.Tensor | None = None,
back_ratio: float | None = None,
use_metropolis_hastings: bool = True,
**policy_kwargs: Any,
) -> Trajectories:
) -> tuple[Trajectories, torch.Tensor]:
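# Metropolis-Hastings needs the log-probs of both the backward and the
# reconstruction passes, so force saving them.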
save_logprobs = save_logprobs or use_metropolis_hastings

bs = trajectories.n_trajectories
state_shape = trajectories.states.state_shape
action_shape = trajectories.env.action_shape
@@ -310,8 +313,6 @@ def local_search(
back_steps,
)

# FIXME: There is a bug in the following code...
# Backward masking is not correctly passed through the local search iterations, I guess...
backward_trajectories = self.backward_sampler.sample_trajectories(
env,
states=trajectories.last_states,
@@ -321,8 +322,46 @@
**policy_kwargs,
)
# Calculate the forward probability if needed (Metropolis-Hastings).
if save_logprobs:
raise NotImplementedError("metropolis-hastings is not implemented yet.")
### COPIED FROM `TrajectoryBasedGFlowNet.get_pfs_and_pbs` ###
if use_metropolis_hastings:
### FIXME: The backward trajectory needs to be reversed before we can evaluate its forward probability.
### TODO: Resolve the issue first https://github.com/GFNOrg/torchgfn/issues/109
valid_states = backward_trajectories.states[
~backward_trajectories.states.is_sink_state
]
valid_actions = backward_trajectories.actions[
~backward_trajectories.actions.is_dummy
]

if backward_trajectories.conditioning is not None:
cond_dim = (-1,) * len(backward_trajectories.conditioning.shape)
traj_len = backward_trajectories.states.tensor.shape[0]
masked_cond = backward_trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~backward_trajectories.states.is_sink_state]

# Here, we pass all valid states, i.e., non-sink states.
with has_conditioning_exception_handler("pf", self.estimator):
estimator_outputs = self.estimator(valid_states, masked_cond)
else:
# Here, we pass all valid states, i.e., non-sink states.
with no_conditioning_exception_handler("pf", self.estimator):
estimator_outputs = self.estimator(valid_states)

# Calculate the log PF of the actions sampled off-policy.
valid_log_pf_actions = self.estimator.to_probability_distribution(
valid_states, estimator_outputs
).log_prob(
valid_actions.tensor
) # Using the actions sampled off-policy.
log_pf_backward_trajectories = torch.full_like(
backward_trajectories.actions.tensor[..., 0],
fill_value=0.0,
dtype=torch.float,
)
log_pf_backward_trajectories[
~backward_trajectories.actions.is_dummy
] = valid_log_pf_actions
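# Dummy-action entries stay at 0, so summing this tensor over the time
# dimension yields each trajectory's total log p_F.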

all_states = backward_trajectories.to_states()
junction_states = all_states[
@@ -338,9 +377,6 @@ def local_search(
save_logprobs=save_logprobs,
**policy_kwargs,
)
# Calculate backward probability if needed (metropolis-hastings).
if save_logprobs:
raise NotImplementedError("metropolis-hastings is not implemented yet.")

# Obtain full trajectories by concatenating the backward and forward parts.
trajectories_dones = (
@@ -358,7 +394,7 @@
) # Episodic reward is assumed.
trajectories_logprobs = None # TODO

for i in range(bs):
for i in range(bs): # FIXME: Can we vectorize this?
# Backward part
back_source_idx = backward_trajectories.when_is_done[i]
n_back = back_source_idx - K[i]
@@ -367,7 +403,7 @@
].flip(0)

# FIXME: This is not correct in general...
# Because the action index may not be consistent with the forward pass.
# Because the action index may not be consistent between the forward and backward passes.
trajectories_actions_tsr[:n_back, i] = backward_trajectories.actions.tensor[
back_source_idx - n_back : back_source_idx, i
].flip(0)
@@ -380,8 +416,6 @@
trajectories_actions_tsr[
n_back : n_back + len_recon - 1, i
] = recon_trajectories.actions.tensor[: len_recon - 1, i]
if save_logprobs: # concatenate log_probs
raise NotImplementedError("metropolis-hastings is not implemented yet.")

new_trajectories = Trajectories(
env=env,
@@ -394,7 +428,82 @@
log_probs=trajectories_logprobs, # TODO: Support log_probs (`None` for now)
)

return new_trajectories
### COPIED FROM `TrajectoryBasedGFlowNet.get_pfs_and_pbs` ###
if use_metropolis_hastings:
valid_states = new_trajectories.states[
~new_trajectories.states.is_sink_state
]
valid_actions = new_trajectories.actions[~new_trajectories.actions.is_dummy]

non_initial_valid_states = valid_states[~valid_states.is_initial_state]
non_exit_valid_actions = valid_actions[~valid_actions.is_exit]

# Using all non-initial states, calculate the backward policy, and the logprobs
# of those actions.
if new_trajectories.conditioning is not None:
# We need to index the conditioning vector to broadcast over the states.
cond_dim = (-1,) * len(new_trajectories.conditioning.shape)
traj_len = new_trajectories.states.tensor.shape[0]
masked_cond = new_trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~new_trajectories.states.is_sink_state][
~valid_states.is_initial_state
]

# Pass all valid states, i.e., non-sink states, except the initial state.
with has_conditioning_exception_handler(
"pb", self.backward_sampler.estimator
):
estimator_outputs = self.backward_sampler.estimator(
non_initial_valid_states, masked_cond
)
else:
# Pass all valid states, i.e., non-sink states, except the initial state.
with no_conditioning_exception_handler(
"pb", self.backward_sampler.estimator
):
estimator_outputs = self.backward_sampler.estimator(
non_initial_valid_states
)

valid_log_pb_actions = (
self.backward_sampler.estimator.to_probability_distribution(
non_initial_valid_states, estimator_outputs
).log_prob(non_exit_valid_actions.tensor)
)

log_pb_new_trajectories = torch.full_like(
new_trajectories.actions.tensor[..., 0],
fill_value=0.0,
dtype=torch.float,
)
log_pb_new_trajectories_slice = torch.full_like(
valid_actions.tensor[..., 0], fill_value=0.0, dtype=torch.float
)
log_pb_new_trajectories_slice[~valid_actions.is_exit] = valid_log_pb_actions
log_pb_new_trajectories[
~new_trajectories.actions.is_dummy
] = log_pb_new_trajectories_slice

# TODO: Implement the Metropolis-Hastings acceptance criterion.
# The proposal first walks back from x to s', then reconstructs forward to x':
# p(x->s'->x') = p_B(x->s')p_F(s'->x')
# p(x'->s'->x) = p_B(x'->s')p_F(s'->x)
# The acceptance probability is
# min(1, R(x')p(x'->s'->x) / R(x)p(x->s'->x'))
# Note that
# p(x'->s'->x) / p(x->s'->x')
# = p_B(x'->s')p_F(s'->x) / p_B(x->s')p_F(s'->x')
# = p_B(x'->s'->s0)p_F(s0->s'->x) / p_B(x->s'->s0)p_F(s0->s'->x')
# = p_B(tau'|x')p_F(tau) / p_B(tau|x)p_F(tau')

# Calculate the acceptance ratio here.
is_updated = None # Sample here
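# A rough sketch of that step, left commented out until the log-prob
# bookkeeping is complete. `log_pf_new_trajectories` is a hypothetical
# per-step (max_len, bs) tensor that would be assembled from the reversed
# backward segment and the reconstruction log-probs; it is not computed
# in this commit.
# log_accept_ratio = torch.clamp_max(
#     new_trajectories.log_rewards
#     + log_pb_new_trajectories.sum(0)  # log p_B(tau'|x')
#     + log_pf_backward_trajectories.sum(0)  # log p_F(tau)
#     - trajectories.log_rewards
#     - backward_trajectories.log_probs.sum(0)  # log p_B(tau|x)
#     - log_pf_new_trajectories.sum(0),  # log p_F(tau')
#     0.0,
# )
# is_updated = (
#     torch.rand(bs, device=log_accept_ratio.device)
#     < log_accept_ratio.exp()
# )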
else:
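# Greedy (hill-climbing) acceptance: keep the proposal iff its reward
# is at least as high as the current trajectory's.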
prev_log_rewards = trajectories.log_rewards
new_log_rewards = new_trajectories.log_rewards
is_updated = prev_log_rewards <= new_log_rewards

return new_trajectories, is_updated

def sample_trajectories(
self,
@@ -452,38 +561,26 @@ def sample_trajectories(
n = trajectories.n_trajectories

search_indices = torch.arange(n, device=trajectories.states.device)
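# search_indices[i] points at the currently accepted trajectory of chain i
# inside the ever-growing `trajectories` container.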
for it in range(n_local_search_loops):
for it in range(n_local_search_loops - 1):
# Search phase
ls_trajectories = self.local_search(
ls_trajectories, is_updated = self.local_search(
env,
trajectories[search_indices],
conditioning,
save_estimator_outputs,
save_logprobs or use_metropolis_hastings,
save_logprobs,
back_steps,
back_ratio,
use_metropolis_hastings,
**policy_kwargs,
)
trajectories.extend(
ls_trajectories
) # Store all regardless of the acceptance.

# Selection phase
if not use_metropolis_hastings:
last_indices = torch.arange(
n * (it + 1), n * (it + 2), device=trajectories.states.device
)
prev_log_rewards = trajectories.log_rewards[search_indices] # type: ignore # FIXME: pyright error
new_log_rewards = ls_trajectories.log_rewards
update_indices = prev_log_rewards <= new_log_rewards
search_indices[update_indices] = last_indices[update_indices]
else: # Metropolis-Hastings acceptance criterion
# TODO: Implement Metropolis-Hastings acceptance criterion.
# We need p(x -> s -> x') = p_B(x -> s) * p_F(s -> x')
# and p(x' -> s' -> x) = p_B(x' -> s') * p_F(s' -> x)
# to calculate the acceptance ratio.
raise NotImplementedError(
"Metropolis-Hastings acceptance criterion is not implemented."
)
last_indices = torch.arange(
n * (it + 1), n * (it + 2), device=trajectories.states.device
)
search_indices[is_updated] = last_indices[is_updated]

return trajectories
2 changes: 1 addition & 1 deletion src/gfn/states.py
@@ -235,7 +235,7 @@ def extend_with_sf(self, required_first_dim: int) -> None:
f"extend_with_sf is not implemented for batch shapes {self.batch_shape}"
)

def compare(self, other: torch.tensor) -> torch.Tensor:
def compare(self, other: torch.Tensor) -> torch.Tensor:
"""Computes elementwise equality between state tensor with an external tensor.
Args:
1 change: 1 addition & 0 deletions tutorials/examples/train_hypergrid_simple_ls.py
@@ -59,6 +59,7 @@ def main(args):
epsilon=args.epsilon,
n_local_search_loops=args.n_local_search_loops,
back_ratio=0.5,
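# NOTE: The Metropolis-Hastings criterion is still WIP in this commit,
# so use greedy (hill-climbing) acceptance instead.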
use_metropolis_hastings=False,
)
optimizer.zero_grad()
loss = gflownet.loss(env, trajectories)