Commit

draft commit
hyeok9855 committed Oct 29, 2024
1 parent 8512dce commit d77f82d
Showing 1 changed file with 221 additions and 3 deletions.
224 changes: 221 additions & 3 deletions src/gfn/samplers.py
@@ -44,7 +44,6 @@ def sample_actions(
"""Samples actions from the given states.
Args:
estimator: A GFNModule to pass to the probability distribution calculator.
env: The environment to sample actions from.
states: A batch of states.
conditioning: An optional tensor of conditioning information.
@@ -206,11 +205,11 @@ def sample_trajectories(
all_estimator_outputs.append(estimator_outputs_padded)

actions[~dones] = valid_actions
trajectories_actions.append(actions)
if save_logprobs:
# When off_policy, actions_log_probs are None.
log_probs[~dones] = actions_log_probs
trajectories_logprobs.append(log_probs)

if self.estimator.is_backward:
new_states = env._backward_step(states, actions)
@@ -267,3 +266,222 @@
)

return trajectories


class LocalSearchSampler(Sampler):
"""Sampler equipped with local search capabilities.
The local search operation is based on back-and-forth heuristic, first proposed
by Zhang et al. 2022 (https://arxiv.org/abs/2202.01361) for negative sampling
and further explored its effectiveness in various applications by Kim et al. 2023
(https://arxiv.org/abs/2310.02710).
Attributes:
estimator: the submitted PolicyEstimator for the forward pass.
pb_estimator: the PolicyEstimator for the backward pass.
"""

def __init__(self, estimator: GFNModule, pb_estimator: GFNModule):
super().__init__(estimator)
self.pb_estimator = pb_estimator

def sample_actions_backward(
self,
env: Env,
states: States,
conditioning: torch.Tensor | None = None,
save_estimator_outputs: bool = False,
save_logprobs: bool = True,
**policy_kwargs: Any,
) -> Tuple[
Actions,
TT["batch_shape", torch.float] | None,
TT["batch_shape", torch.float] | None,
]:
"""Samples backward actions from the given states.
Args:
env: The environment to sample actions from.
states: A batch of states.
conditioning: An optional tensor of conditioning information.
save_estimator_outputs: If True, the estimator outputs will be returned.
save_logprobs: If True, calculates and saves the log probabilities of sampled
actions.
policy_kwargs: keyword arguments to be passed to the
`to_probability_distribution` method of the estimator. For example, for
DiscretePolicyEstimators, the kwargs can contain the `temperature`
parameter, `epsilon`, and `sf_bias`. In the continuous case these
kwargs will be user defined. This can be used to, for example, sample
off-policy.
When sampling off-policy, set `save_estimator_outputs=True` and
`save_logprobs=False`. Log probabilities are instead calculated during the
computation of `PF` as part of the `GFlowNet` class, and the estimator
outputs are required for estimating the logprobs of these off-policy
actions.
Returns:
A tuple containing:
- An Actions object containing the sampled backward actions.
- A tensor of shape (*batch_shape,) containing the log probabilities of
the sampled actions under the backward policy distribution at the given
states, or None if save_logprobs is False.
- The estimator outputs, or None if save_estimator_outputs is False.
"""
if conditioning is not None:
with has_conditioning_exception_handler("pb_estimator", self.pb_estimator):
pb_estimator_output = self.pb_estimator(states, conditioning)
else:
with no_conditioning_exception_handler("pb_estimator", self.pb_estimator):
pb_estimator_output = self.pb_estimator(states)

dist = self.pb_estimator.to_probability_distribution(
states, pb_estimator_output, **policy_kwargs
)

with torch.no_grad():
actions = dist.sample()

if save_logprobs:
log_probs = dist.log_prob(actions)
if torch.any(torch.isinf(log_probs)):
raise RuntimeError("Log probabilities are inf. This should not happen.")
else:
log_probs = None

actions = env.actions_from_tensor(actions)

if not save_estimator_outputs:
pb_estimator_output = None

return actions, log_probs, pb_estimator_output

def sample_trajectories_backward(
self, trajectories: Trajectories, n_steps: TT["batch_shape", torch.int]
) -> List[States]:
"""
Args:
trajectories: A batch of trajectories.
n_steps: The number of steps to backtrack.
"""
partial_trajectories = trajectories # TODO
return partial_trajectories

def local_search(
self,
env: Env,
trajectories: Trajectories,
conditioning: torch.Tensor | None = None,
save_estimator_outputs: bool = False,
save_logprobs: bool = True,
back_steps: TT["batch_shape", torch.int] | None = None,
back_ratio: float | None = None,
metropolis_hastings: bool = False,
**policy_kwargs: Any,
) -> Trajectories:
# K-step backward sampling with the backward estimator,
# where K is the number of backward steps used in https://arxiv.org/abs/2202.01361.
if back_steps is None:
assert (
back_ratio is not None
), "Either kwarg `back_steps` or `back_ratio` must be specified"
K = torch.ceil(back_ratio * trajectories.when_is_done)
else:
K = torch.where(
back_steps > trajectories.when_is_done,
trajectories.when_is_done,
back_steps,
)
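# Hedged numeric illustration: with back_ratio=0.5 and when_is_done = [4, 7],
# K = ceil([2.0, 3.5]) = [2., 4.]; with back_steps=3 instead, K = [3, 3]
# (i.e., the elementwise minimum of back_steps and when_is_done).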


partial_trajectories = self.sample_trajectories_backward(
trajectories, n_steps=K
)

### TODO: Reconstruct the remaining forward segments from the backtracked states
### with self.estimator (the forward policy) and return the resulting Trajectories.

def sample_trajectories(
self,
env: Env,
n: Optional[int] = None,
states: Optional[States] = None,
conditioning: Optional[torch.Tensor] = None,
save_estimator_outputs: bool = False,
save_logprobs: bool = True,
n_local_search_loops: int = 0,
back_steps: TT["batch_shape", torch.int] | None = None,
back_ratio: float | None = None,
metropolis_hastings: bool = False,
**policy_kwargs: Any,
) -> Trajectories:
"""Sample trajectories sequentially with optional local search.
Args:
env: The environment to sample trajectories from.
n: If given, a batch of n trajectories will be sampled, all
starting from the environment's s_0.
states: If given, trajectories will start from these states. Otherwise,
trajectories are sampled from $s_0$ and n must be provided.
conditioning: An optional tensor of conditioning information.
save_estimator_outputs: If True, the estimator outputs will be returned. This
is useful for off-policy training with a tempered policy.
save_logprobs: If True, calculates and saves the log probabilities of sampled
actions. This is useful for on-policy training.
n_local_search_loops: The number of local search loops to apply after the
initial sampling.
back_steps: The number of backward steps to take during local search.
back_ratio: The ratio of the number of backward steps to the length of the trajectory.
metropolis_hastings: If True, applies Metropolis-Hastings acceptance criterion.
policy_kwargs: keyword arguments to be passed to the
`to_probability_distribution` method of the estimator. For example, for
DiscretePolicyEstimators, the kwargs can contain the `temperature`
parameter, `epsilon`, and `sf_bias`. In the continuous case these
kwargs will be user defined. This can be used to, for example, sample
off-policy.
Returns: A Trajectories object representing the batch of sampled trajectories,
where the batch size is n * (1 + n_local_search_loops).
"""

trajectories = super().sample_trajectories(
env,
n,
states,
conditioning,
save_estimator_outputs,
save_logprobs,
**policy_kwargs,
)
all_trajectories = trajectories
for _ in range(n_local_search_loops):
# Search phase
ls_trajectories = self.local_search(
env,
trajectories,
conditioning,
save_estimator_outputs,
save_logprobs,
back_steps,
back_ratio,
metropolis_hastings,
**policy_kwargs,
)
all_trajectories.extend(
ls_trajectories
) # Store all regardless of the acceptance.

# Selection phase
if not metropolis_hastings:
update_indices = trajectories.log_rewards < ls_trajectories.log_rewards
trajectories[update_indices] = ls_trajectories[update_indices]
else: # Metropolis-Hastings acceptance criterion
# TODO: Implement Metropolis-Hastings acceptance criterion.
# We need p(x -> s -> x') = p_B(x -> s) * p_F(s -> x')
# and p(x' -> s' -> x) = p_B(x' -> s') * p_F(s' -> x)
# to calculate the acceptance ratio.
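# One hedged possibility, assuming the target distribution is proportional to the
# reward R(x), as in standard Metropolis-Hastings:
#   A(x -> x') = min(1, [R(x') * p(x' -> s' -> x)] / [R(x) * p(x -> s -> x')])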
raise NotImplementedError(
"Metropolis-Hastings acceptance criterion is not implemented."
)

return all_trajectories
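
For reference, a minimal usage sketch of the new sampler (hedged: it assumes `pf_estimator` and `pb_estimator` are already-constructed forward and backward policy estimators for `env`, e.g. DiscretePolicyEstimator instances, and that the TODOs above are filled in):

from gfn.samplers import LocalSearchSampler

sampler = LocalSearchSampler(estimator=pf_estimator, pb_estimator=pb_estimator)
trajectories = sampler.sample_trajectories(
    env,
    n=16,                    # 16 initial trajectories from s_0
    save_logprobs=True,      # needed for on-policy training objectives
    n_local_search_loops=2,  # two rounds of back-and-forth local search
    back_ratio=0.5,          # backtrack half of each trajectory's length
)
# Per the docstring above, the returned batch holds n * (1 + n_local_search_loops) = 48 trajectories.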
