
Commit

draft commit
hyeok9855 committed Oct 29, 2024
1 parent 57cc269 commit 1ccf16c
Showing 4 changed files with 363 additions and 15 deletions.
27 changes: 17 additions & 10 deletions src/gfn/containers/trajectories.py
@@ -92,20 +92,29 @@ def __init__(
if when_is_done is not None
else torch.full(size=(0,), fill_value=-1, dtype=torch.long)
)
-        assert self.when_is_done.shape == (self.n_trajectories,) and self.when_is_done.dtype == torch.long
+        assert (
+            self.when_is_done.shape == (self.n_trajectories,)
+            and self.when_is_done.dtype == torch.long
+        )

self._log_rewards = (
log_rewards
if log_rewards is not None
else torch.full(size=(0,), fill_value=0, dtype=torch.float)
)
-        assert self._log_rewards.shape == (self.n_trajectories,) and self._log_rewards.dtype == torch.float
+        assert (
+            self._log_rewards.shape == (self.n_trajectories,)
+            and self._log_rewards.dtype == torch.float
+        )

-        if log_probs is not None:
-            assert log_probs.shape == (self.max_length, self.n_trajectories) and log_probs.dtype == torch.float
+        if log_probs is not None and log_probs.shape != (0, 0):
+            assert (
+                log_probs.shape == (self.max_length, self.n_trajectories)
+                and log_probs.dtype == torch.float
+            )
        else:
            log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
        self.log_probs = log_probs
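        # Note: an empty (0, 0) tensor serves as the placeholder for "no
        # log-probs stored"; the new guard above skips the shape/dtype assert
        # in that case.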

self.estimator_outputs = estimator_outputs
if self.estimator_outputs is not None:
@@ -207,15 +216,13 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories:
)

@staticmethod
-    def extend_log_probs(
-        log_probs: torch.Tensor, new_max_length: int
-    ) -> torch.Tensor:
+    def extend_log_probs(log_probs: torch.Tensor, new_max_length: int) -> torch.Tensor:
"""Extend the log_probs matrix by adding 0 until the required length is reached.
Args:
log_probs: The log_probs tensor of shape (max_length, n_trajectories) to extend.
new_max_length: The new length of the log_probs tensor.
Returns: The extended log_probs tensor of shape (new_max_length, n_trajectories).
"""
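The method body is collapsed in this diff view; a minimal sketch consistent
with the docstring, assuming zero-padding along the time dimension, could
look like:

import torch

def extend_log_probs(log_probs: torch.Tensor, new_max_length: int) -> torch.Tensor:
    """Pad log_probs with zeros along dim 0 up to new_max_length (sketch)."""
    max_length, n_trajectories = log_probs.shape
    if max_length >= new_max_length:
        return log_probs
    pad = torch.zeros(
        new_max_length - max_length,
        n_trajectories,
        dtype=log_probs.dtype,
        device=log_probs.device,
    )
    return torch.cat([log_probs, pad], dim=0)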
232 changes: 229 additions & 3 deletions src/gfn/samplers.py
@@ -39,7 +39,6 @@ def sample_actions(
"""Samples actions from the given states.
Args:
estimator: A GFNModule to pass to the probability distribution calculator.
-            env: The environment to sample actions from.
states: A batch of states.
conditioning: An optional tensor of conditioning information.
@@ -203,11 +202,11 @@ def sample_trajectories(
all_estimator_outputs.append(estimator_outputs_padded)

actions[~dones] = valid_actions
+            trajectories_actions.append(actions)
            if save_logprobs:
                # When off_policy, actions_log_probs are None.
                log_probs[~dones] = actions_log_probs
-            trajectories_actions.append(actions)
-            trajectories_logprobs.append(log_probs)
+                trajectories_logprobs.append(log_probs)

if self.estimator.is_backward:
new_states = env._backward_step(states, actions)
@@ -264,3 +263,230 @@ def sample_trajectories(
)

return trajectories


class LocalSearchSampler(Sampler):
"""Sampler equipped with local search capabilities.
    The local search operation is based on the back-and-forth heuristic, first
    proposed by Zhang et al., 2022 (https://arxiv.org/abs/2202.01361) for negative
    sampling; its effectiveness in various applications was further explored by
    Kim et al., 2023 (https://arxiv.org/abs/2310.02710).
Attributes:
        estimator: the PolicyEstimator used for the forward pass.
pb_estimator: the PolicyEstimator for the backward pass.
"""

def __init__(self, estimator: GFNModule, pb_estimator: GFNModule):
super().__init__(estimator)
self.backward_sampler = Sampler(pb_estimator)

def local_search(
self,
env: Env,
trajectories: Trajectories,
conditioning: torch.Tensor | None = None,
save_estimator_outputs: bool = False,
save_logprobs: bool = True,
back_steps: torch.Tensor | None = None,
back_ratio: float | None = None,
**policy_kwargs: Any,
) -> Trajectories:
bs = trajectories.n_trajectories
state_shape = trajectories.states.state_shape
action_shape = trajectories.env.action_shape

# K-step backward sampling with the backward estimator,
# where K is the number of backward steps used in https://arxiv.org/abs/2202.01361.
if back_steps is None:
assert (
back_ratio is not None
), "Either kwarg `back_steps` or `back_ratio` must be specified"
K = torch.ceil(back_ratio * (trajectories.when_is_done - 1)).long()
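            # e.g., back_ratio=0.5 with when_is_done=9 gives K = ceil(0.5 * 8) = 4,
            # i.e., the last 4 steps of that trajectory are destroyed and resampled.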
else:
K = torch.where(
back_steps > trajectories.when_is_done,
trajectories.when_is_done,
back_steps,
)

        # FIXME: There is a bug in the following code: backward masking does not
        # appear to be passed correctly through the local search iterations.
backward_trajectories = self.backward_sampler.sample_trajectories(
env,
states=trajectories.last_states,
conditioning=conditioning,
save_estimator_outputs=save_estimator_outputs,
save_logprobs=save_logprobs,
**policy_kwargs,
)
        # Calculate the forward probability if needed (Metropolis-Hastings).
if save_logprobs:
raise NotImplementedError("metropolis-hastings is not implemented yet.")

all_states = backward_trajectories.to_states()
junction_states = all_states[
torch.arange(bs, device=all_states.device) + bs * K
]
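        # all_states is assumed to be flattened time-major (all step-0 states
        # first), so index i + bs * K[i] presumably picks, for trajectory i, the
        # state reached after K[i] backward steps: the junction where
        # reconstruction starts.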

### Reconstructing with self.estimator
recon_trajectories = super().sample_trajectories(
env,
states=junction_states,
conditioning=conditioning,
save_estimator_outputs=save_estimator_outputs,
save_logprobs=save_logprobs,
**policy_kwargs,
)
        # Calculate the backward probability if needed (Metropolis-Hastings).
if save_logprobs:
raise NotImplementedError("metropolis-hastings is not implemented yet.")

# Obtain full trajectories by concatenating the backward and forward parts.
trajectories_dones = (
backward_trajectories.when_is_done - K + recon_trajectories.when_is_done
)
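        # = (steps kept from the backward part) + (steps re-generated forward);
        # e.g., when_is_done=9 with K=4 and a 6-step reconstruction gives 9 - 4 + 6 = 11.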
max_traj_len = trajectories_dones.max() + 1
trajectories_states_tsr = torch.full((max_traj_len, bs, *state_shape), -1).to(
all_states.tensor
)
trajectories_actions_tsr = torch.full(
(max_traj_len - 1, bs, *action_shape), -1
).to(all_states.tensor)
trajectories_log_rewards = (
recon_trajectories.log_rewards
) # Episodic reward is assumed.
trajectories_logprobs = None # TODO

len_recon = recon_trajectories.states.tensor.shape[0]
for i in range(bs):
# Backward part
back_source_idx = backward_trajectories.when_is_done[i]
n_back = back_source_idx - K[i]
trajectories_states_tsr[:n_back, i] = backward_trajectories.states.tensor[
back_source_idx - n_back + 1 : back_source_idx + 1, i
].flip(0)

            # FIXME: This is not correct in general, because the backward action
            # indices may not be consistent with their forward counterparts.
trajectories_actions_tsr[:n_back, i] = backward_trajectories.actions.tensor[
back_source_idx - n_back : back_source_idx, i
].flip(0)

# Forward part
trajectories_states_tsr[n_back : n_back + len_recon, i] = (
recon_trajectories.states.tensor[:, i]
)
trajectories_actions_tsr[n_back : n_back + len_recon - 1, i] = (
recon_trajectories.actions.tensor[: len_recon - 1, i]
)
if save_logprobs: # concatenate log_probs
raise NotImplementedError("metropolis-hastings is not implemented yet.")

new_trajectories = Trajectories(
env=env,
states=env.States(trajectories_states_tsr),
conditioning=conditioning,
actions=env.Actions(trajectories_actions_tsr),
when_is_done=trajectories_dones,
is_backward=False,
            # FIXME: This is awkward, since the trajectory contains both
            # backward and forward parts. Maybe calculate log_pfs for the
            # backward part and set is_backward=True?
log_rewards=trajectories_log_rewards,
log_probs=trajectories_logprobs, # TODO: Support log_probs (`None` for now)
)

return new_trajectories

def sample_trajectories(
self,
env: Env,
n: Optional[int] = None,
states: Optional[States] = None,
conditioning: Optional[torch.Tensor] = None,
        save_estimator_outputs: bool = False,  # FIXME: currently does not work when this is True
save_logprobs: bool = True, # TODO: Support save_logprobs=True
n_local_search_loops: int = 0,
back_steps: torch.Tensor | None = None,
back_ratio: float | None = None,
use_metropolis_hastings: bool = False,
**policy_kwargs: Any,
) -> Trajectories:
"""Sample trajectories sequentially with optional local search.
Args:
env: The environment to sample trajectories from.
            n: If given, a batch of `n` trajectories will be sampled, all
                starting from the environment's $s_0$.
            states: If given, trajectories start from these states. Otherwise,
                trajectories are sampled from $s_0$ and `n` must be provided.
conditioning: An optional tensor of conditioning information.
save_estimator_outputs: If True, the estimator outputs will be returned. This
is useful for off-policy training with tempered policy.
save_logprobs: If True, calculates and saves the log probabilities of sampled
actions. This is useful for on-policy training.
            n_local_search_loops: The number of local search loops to perform.
back_steps: The number of backward steps.
back_ratio: The ratio of the number of backward steps to the length of the trajectory.
use_metropolis_hastings: If True, applies Metropolis-Hastings acceptance criterion.
policy_kwargs: keyword arguments to be passed to the
`to_probability_distribution` method of the estimator. For example, for
DiscretePolicyEstimators, the kwargs can contain the `temperature`
parameter, `epsilon`, and `sf_bias`. In the continuous case these
kwargs will be user defined. This can be used to, for example, sample
off-policy.
Returns: A Trajectories object representing the batch of sampled trajectories,
where the batch size is n * (1 + n_local_search_loops).
"""

trajectories = super().sample_trajectories(
env,
n,
states,
conditioning,
save_estimator_outputs,
save_logprobs or use_metropolis_hastings,
**policy_kwargs,
)

if n is None:
n = trajectories.n_trajectories

search_indices = torch.arange(n, device=trajectories.states.device)
for it in range(n_local_search_loops):
# Search phase
ls_trajectories = self.local_search(
env,
trajectories[search_indices],
conditioning,
save_estimator_outputs,
save_logprobs or use_metropolis_hastings,
back_steps,
back_ratio,
**policy_kwargs,
)
trajectories.extend(
ls_trajectories
) # Store all regardless of the acceptance.

# Selection phase
if not use_metropolis_hastings:
last_indices = torch.arange(
n * (it + 1), n * (it + 2), device=trajectories.states.device
)
prev_log_rewards = trajectories.log_rewards[search_indices]
new_log_rewards = ls_trajectories.log_rewards
update_indices = prev_log_rewards <= new_log_rewards
search_indices[update_indices] = last_indices[update_indices]
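                # Greedy hill-climbing: the search pointer moves to the refined
                # trajectory whenever its log-reward is at least as high.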
else: # Metropolis-Hastings acceptance criterion
# TODO: Implement Metropolis-Hastings acceptance criterion.
# We need p(x -> s -> x') = p_B(x -> s) * p_F(s -> x')
# and p(x' -> s' -> x) = p_B(x' -> s') * p_F(s' -> x)
# to calculate the acceptance ratio.
raise NotImplementedError(
"Metropolis-Hastings acceptance criterion is not implemented."
)

return trajectories
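
For context, a hypothetical usage sketch of the new sampler; pf_estimator,
pb_estimator, and env are assumed to be constructed as in the torchgfn
tutorials and are not part of this commit:

sampler = LocalSearchSampler(estimator=pf_estimator, pb_estimator=pb_estimator)
trajectories = sampler.sample_trajectories(
    env,
    n=16,
    n_local_search_loops=2,
    back_ratio=0.5,  # destroy roughly the last half of each trajectory
    save_logprobs=False,  # True would hit the unimplemented Metropolis-Hastings path
)
# Per the docstring, the returned batch holds n * (1 + n_local_search_loops)
# = 16 * 3 = 48 trajectories.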
3 changes: 1 addition & 2 deletions tutorials/examples/train_hypergrid_simple.py
@@ -41,8 +41,7 @@ def main(args):
sampler = Sampler(estimator=pf_estimator)

    # Move the gflownet to the selected device.
-    if torch.cuda.is_available():
-        gflownet = gflownet.to("cuda")
+    gflownet = gflownet.to(device_str)
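    # `device_str` is assumed to be defined earlier in the script, e.g.,
    # "cuda" if torch.cuda.is_available() else "cpu".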

# Policy parameters have their own LR. Log Z gets dedicated learning rate
# (typically higher).
(Diff for the fourth changed file did not load and is not shown.)