diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py
index e2f5247b..b0d3070d 100644
--- a/src/gfn/samplers.py
+++ b/src/gfn/samplers.py
@@ -44,7 +44,6 @@ def sample_actions(
         """Samples actions from the given states.

         Args:
-            estimator: A GFNModule to pass to the probability distribution calculator.
             env: The environment to sample actions from.
             states: A batch of states.
             conditioning: An optional tensor of conditioning information.
@@ -206,11 +205,11 @@ def sample_trajectories(
                 all_estimator_outputs.append(estimator_outputs_padded)

             actions[~dones] = valid_actions
+            trajectories_actions.append(actions)
             if save_logprobs:
                 # When off_policy, actions_log_probs are None.
                 log_probs[~dones] = actions_log_probs
-            trajectories_actions.append(actions)
-            trajectories_logprobs.append(log_probs)
+                trajectories_logprobs.append(log_probs)

             if self.estimator.is_backward:
                 new_states = env._backward_step(states, actions)
@@ -267,3 +266,222 @@ def sample_trajectories(
         )

         return trajectories
+
+
+class LocalSearchSampler(Sampler):
+    """Sampler equipped with local search capabilities.
+
+    The local search operation is based on the back-and-forth heuristic, first
+    proposed by Zhang et al. 2022 (https://arxiv.org/abs/2202.01361) for negative
+    sampling, whose effectiveness in various applications was further explored by
+    Kim et al. 2023 (https://arxiv.org/abs/2310.02710).
+
+    Attributes:
+        estimator: the PolicyEstimator for the forward pass.
+        pb_estimator: the PolicyEstimator for the backward pass.
+    """
+
+    def __init__(self, estimator: GFNModule, pb_estimator: GFNModule):
+        super().__init__(estimator)
+        self.pb_estimator = pb_estimator
+
+    def sample_actions_backward(
+        self,
+        env: Env,
+        states: States,
+        conditioning: torch.Tensor | None = None,
+        save_estimator_outputs: bool = False,
+        save_logprobs: bool = True,
+        **policy_kwargs: Any,
+    ) -> Tuple[
+        Actions,
+        TT["batch_shape", torch.float] | None,
+        TT["batch_shape", torch.float] | None,
+    ]:
+        """Samples backward actions from the given states.
+
+        Args:
+            env: The environment to sample actions from.
+            states: A batch of states.
+            conditioning: An optional tensor of conditioning information.
+            save_estimator_outputs: If True, the estimator outputs will be returned.
+            save_logprobs: If True, calculates and saves the log probabilities of
+                sampled actions.
+            policy_kwargs: keyword arguments to be passed to the
+                `to_probability_distribution` method of the estimator. For example, for
+                DiscretePolicyEstimators, the kwargs can contain the `temperature`
+                parameter, `epsilon`, and `sf_bias`. In the continuous case these
+                kwargs will be user defined. This can be used to, for example, sample
+                off-policy.
+
+            When sampling off-policy, set `save_estimator_outputs=True` and
+                `save_logprobs=False`. Log probabilities are instead calculated during
+                the computation of `PF` as part of the `GFlowNet` class, and the
+                estimator outputs are required for estimating the logprobs of these
+                off-policy actions.
+
+        Returns:
+            A tuple containing:
+             - An Actions object containing the sampled actions.
+             - A tensor of shape (*batch_shape,) containing the log probabilities of
+                the sampled actions under the probability distribution of the given
+                states, or None if `save_logprobs` is False.
+             - The estimator outputs if `save_estimator_outputs` is True, otherwise
+                None.
+ """ + if conditioning is not None: + with has_conditioning_exception_handler("pb_estimator", self.pb_estimator): + pb_estimator_output = self.pb_estimator(states, conditioning) + else: + with no_conditioning_exception_handler("pb_estimator", self.pb_estimator): + pb_estimator_output = self.pb_estimator(states) + + dist = self.pb_estimator.to_probability_distribution( + states, pb_estimator_output, **policy_kwargs + ) + + with torch.no_grad(): + actions = dist.sample() + + if save_logprobs: + log_probs = dist.log_prob(actions) + if torch.any(torch.isinf(log_probs)): + raise RuntimeError("Log probabilities are inf. This should not happen.") + else: + log_probs = None + + actions = env.actions_from_tensor(actions) + + if not save_estimator_outputs: + pb_estimator_output = None + + return actions, log_probs, pb_estimator + + def sample_trajectories_backward( + self, trajectories: Trajectories, n_steps: TT["batch_shape", torch.int] + ) -> List[States]: + """ + Args: + trajectories: A batch of trajectories. + n_steps: The number of steps to backtrack. + """ + partial_trajectories = trajectories # TODO + return partial_trajectories + + def local_search( + self, + env: Env, + trajectories: Trajectories, + conditioning: torch.Tensor | None = None, + save_estimator_outputs: bool = False, + save_logprobs: bool = True, + back_steps: TT["batch_shape", torch.int] | None = None, + back_ratio: float | None = None, + metropolis_hastings: bool = False, + **policy_kwargs: Any, + ) -> Trajectories: + # K-step backward sampling with the backward estimator, + # where K is the number of backward steps used in https://arxiv.org/abs/2202.01361. + if back_steps is None: + assert ( + back_ratio is not None + ), "Either kwarg `back_steps` or `back_ratio` must be specified" + K = torch.ceil(back_ratio * trajectories.when_is_done) + else: + K = torch.where( + back_steps > trajectories.when_is_done, + trajectories.when_is_done, + back_steps, + ) + + import pdb + + pdb.set_trace() + + partial_trajectories = self.sample_trajectories_backward( + trajectories, n_steps=K + ) + + ### Reconstructing with self.estimator + ### TODO + + def sample_trajectories( + self, + env: Env, + n: Optional[int] = None, + states: Optional[States] = None, + conditioning: Optional[torch.Tensor] = None, + save_estimator_outputs: bool = False, + save_logprobs: bool = True, + n_local_search_loops: int = 0, + back_steps: TT["batch_shape", torch.int] | None = None, + back_ratio: float | None = None, + metropolis_hastings: bool = False, + **policy_kwargs: Any, + ) -> Trajectories: + """Sample trajectories sequentially with optional local search. + + Args: + env: The environment to sample trajectories from. + n: If given, a batch of n_trajectories will be sampled all + starting from the environment's s_0. + states: If given, trajectories would start from such states. Otherwise, + trajectories are sampled from $s_o$ and n_trajectories must be provided. + conditioning: An optional tensor of conditioning information. + save_estimator_outputs: If True, the estimator outputs will be returned. This + is useful for off-policy training with tempered policy. + save_logprobs: If True, calculates and saves the log probabilities of sampled + actions. This is useful for on-policy training. + local_search: If True, applies local search operation. + back_steps: The number of backward steps. + back_ratio: The ratio of the number of backward steps to the length of the trajectory. + metropolis_hastings: If True, applies Metropolis-Hastings acceptance criterion. 
+            policy_kwargs: keyword arguments to be passed to the
+                `to_probability_distribution` method of the estimator. For example, for
+                DiscretePolicyEstimators, the kwargs can contain the `temperature`
+                parameter, `epsilon`, and `sf_bias`. In the continuous case these
+                kwargs will be user defined. This can be used to, for example, sample
+                off-policy.
+
+        Returns: A Trajectories object representing the batch of sampled trajectories,
+            where the batch size is n * (1 + n_local_search_loops).
+        """
+        trajectories = super().sample_trajectories(
+            env,
+            n,
+            states,
+            conditioning,
+            save_estimator_outputs,
+            save_logprobs,
+            **policy_kwargs,
+        )
+
+        # Keep the running collection in a separate object: `trajectories` is
+        # updated in place in the selection phase below, so extending an alias
+        # of it would also grow the batch that seeds the next local search loop.
+        from copy import deepcopy  # local import, to keep this diff self-contained
+
+        all_trajectories = deepcopy(trajectories)
+        for _ in range(n_local_search_loops):
+            # Search phase
+            ls_trajectories = self.local_search(
+                env,
+                trajectories,
+                conditioning,
+                save_estimator_outputs,
+                save_logprobs,
+                back_steps,
+                back_ratio,
+                metropolis_hastings,
+                **policy_kwargs,
+            )
+            all_trajectories.extend(
+                ls_trajectories
+            )  # Store all candidates regardless of the acceptance.
+
+            # Selection phase
+            if not metropolis_hastings:
+                update_indices = trajectories.log_rewards < ls_trajectories.log_rewards
+                trajectories[update_indices] = ls_trajectories[update_indices]
+            else:  # Metropolis-Hastings acceptance criterion
+                # TODO: Implement Metropolis-Hastings acceptance criterion.
+                # We need p(x -> s -> x') = p_B(x -> s) * p_F(s -> x')
+                # and p(x' -> s' -> x) = p_B(x' -> s') * p_F(s' -> x)
+                # to calculate the acceptance ratio.
+                raise NotImplementedError(
+                    "Metropolis-Hastings acceptance criterion is not implemented."
+                )
+
+        return all_trajectories
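
For reference, a minimal usage sketch of the new `LocalSearchSampler` (not part of the diff). It assumes `env`, `pf_estimator`, and `pb_estimator` are an environment and forward/backward policy estimators built as in the existing torchgfn examples; note that `local_search` is still a TODO in this diff, so the local-search loops are not functional yet.

    from gfn.samplers import LocalSearchSampler

    # Assumed to exist: env (an Env), pf_estimator / pb_estimator (forward and
    # backward policy GFNModules), e.g. built as in the HyperGrid example.
    sampler = LocalSearchSampler(estimator=pf_estimator, pb_estimator=pb_estimator)

    # Sample 16 trajectories, then run two local search loops. Each loop
    # backtracks roughly half of every trajectory (back_ratio=0.5) and keeps a
    # candidate whenever it improves the log-reward, so the returned batch
    # contains 16 * (1 + 2) = 48 trajectories.
    trajectories = sampler.sample_trajectories(
        env,
        n=16,
        save_logprobs=True,
        n_local_search_loops=2,
        back_ratio=0.5,
    )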