diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py
index 5de2d494..6d38bc2d 100644
--- a/src/gfn/containers/trajectories.py
+++ b/src/gfn/containers/trajectories.py
@@ -92,20 +92,29 @@ def __init__(
             if when_is_done is not None
             else torch.full(size=(0,), fill_value=-1, dtype=torch.long)
         )
-        assert self.when_is_done.shape == (self.n_trajectories,) and self.when_is_done.dtype == torch.long
+        assert (
+            self.when_is_done.shape == (self.n_trajectories,)
+            and self.when_is_done.dtype == torch.long
+        )
 
         self._log_rewards = (
             log_rewards
             if log_rewards is not None
             else torch.full(size=(0,), fill_value=0, dtype=torch.float)
         )
-        assert self._log_rewards.shape == (self.n_trajectories,) and self._log_rewards.dtype == torch.float
+        assert (
+            self._log_rewards.shape == (self.n_trajectories,)
+            and self._log_rewards.dtype == torch.float
+        )
 
-        if log_probs is not None:
-            assert log_probs.shape == (self.max_length, self.n_trajectories) and log_probs.dtype == torch.float
+        if log_probs is not None and log_probs.shape != (0, 0):
+            assert (
+                log_probs.shape == (self.max_length, self.n_trajectories)
+                and log_probs.dtype == torch.float
+            )
         else:
             log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
-        self.log_probs = log_probs
+        self.log_probs = log_probs
 
         self.estimator_outputs = estimator_outputs
         if self.estimator_outputs is not None:
@@ -207,15 +216,13 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories:
         )
 
     @staticmethod
-    def extend_log_probs(
-        log_probs: torch.Tensor, new_max_length: int
-    ) -> torch.Tensor:
+    def extend_log_probs(log_probs: torch.Tensor, new_max_length: int) -> torch.Tensor:
         """Extend the log_probs matrix by adding 0 until the required length is reached.
-        
+
         Args:
             log_probs: The log_probs tensor of shape (max_length, n_trajectories) to extend.
             new_max_length: The new length of the log_probs tensor.
-        
+
         Returns:
             The extended log_probs tensor of shape (new_max_length, n_trajectories).
         """
diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py
index 819620f0..3a8c2d3d 100644
--- a/src/gfn/samplers.py
+++ b/src/gfn/samplers.py
@@ -39,7 +39,6 @@ def sample_actions(
         """Samples actions from the given states.
 
         Args:
-            estimator: A GFNModule to pass to the probability distribution calculator.
            env: The environment to sample actions from.
            states: A batch of states.
            conditioning: An optional tensor of conditioning information.
@@ -203,11 +202,11 @@ def sample_trajectories(
                 all_estimator_outputs.append(estimator_outputs_padded)
 
             actions[~dones] = valid_actions
+            trajectories_actions.append(actions)
             if save_logprobs:
                 # When off_policy, actions_log_probs are None.
                 log_probs[~dones] = actions_log_probs
-            trajectories_actions.append(actions)
-            trajectories_logprobs.append(log_probs)
+                trajectories_logprobs.append(log_probs)
 
             if self.estimator.is_backward:
                 new_states = env._backward_step(states, actions)
@@ -264,3 +263,230 @@ def sample_trajectories(
         )
 
         return trajectories
+
+
+class LocalSearchSampler(Sampler):
+    """Sampler equipped with local search capabilities.
+    The local search operation is based on the back-and-forth heuristic first
+    proposed by Zhang et al. 2022 (https://arxiv.org/abs/2202.01361) for negative
+    sampling; its effectiveness in various applications was further explored by
+    Kim et al. 2023 (https://arxiv.org/abs/2310.02710).
+
+    Attributes:
+        estimator: the PolicyEstimator used for the forward pass.
+        pb_estimator: the PolicyEstimator used for the backward pass.
+    """
+
+    def __init__(self, estimator: GFNModule, pb_estimator: GFNModule):
+        super().__init__(estimator)
+        self.backward_sampler = Sampler(pb_estimator)
+
+    def local_search(
+        self,
+        env: Env,
+        trajectories: Trajectories,
+        conditioning: torch.Tensor | None = None,
+        save_estimator_outputs: bool = False,
+        save_logprobs: bool = True,
+        back_steps: torch.Tensor | None = None,
+        back_ratio: float | None = None,
+        **policy_kwargs: Any,
+    ) -> Trajectories:
+        bs = trajectories.n_trajectories
+        state_shape = trajectories.states.state_shape
+        action_shape = trajectories.env.action_shape
+
+        # K-step backward sampling with the backward estimator,
+        # where K is the number of backward steps used in https://arxiv.org/abs/2202.01361.
+        if back_steps is None:
+            assert (
+                back_ratio is not None
+            ), "Either kwarg `back_steps` or `back_ratio` must be specified"
+            K = torch.ceil(back_ratio * (trajectories.when_is_done - 1)).long()
+        else:
+            K = torch.where(
+                back_steps > trajectories.when_is_done,
+                trajectories.when_is_done,
+                back_steps,
+            )
+
+        # FIXME: There is a bug in the following code: backward masking does not
+        # seem to be propagated correctly through the local search iterations.
+        backward_trajectories = self.backward_sampler.sample_trajectories(
+            env,
+            states=trajectories.last_states,
+            conditioning=conditioning,
+            save_estimator_outputs=save_estimator_outputs,
+            save_logprobs=save_logprobs,
+            **policy_kwargs,
+        )
+        # Calculate the forward probability if needed (Metropolis-Hastings).
+        if save_logprobs:
+            raise NotImplementedError("metropolis-hastings is not implemented yet.")
+
+        all_states = backward_trajectories.to_states()
+        junction_states = all_states[
+            torch.arange(bs, device=all_states.device) + bs * K
+        ]
+
+        # Reconstruct the forward part of each trajectory with self.estimator.
+        recon_trajectories = super().sample_trajectories(
+            env,
+            states=junction_states,
+            conditioning=conditioning,
+            save_estimator_outputs=save_estimator_outputs,
+            save_logprobs=save_logprobs,
+            **policy_kwargs,
+        )
+        # Calculate the backward probability if needed (Metropolis-Hastings).
+        if save_logprobs:
+            raise NotImplementedError("metropolis-hastings is not implemented yet.")
+
+        # Obtain full trajectories by concatenating the backward and forward parts.
+        trajectories_dones = (
+            backward_trajectories.when_is_done - K + recon_trajectories.when_is_done
+        )
+        max_traj_len = trajectories_dones.max() + 1
+        trajectories_states_tsr = torch.full((max_traj_len, bs, *state_shape), -1).to(
+            all_states.tensor
+        )
+        trajectories_actions_tsr = torch.full(
+            (max_traj_len - 1, bs, *action_shape), -1
+        ).to(all_states.tensor)
+        trajectories_log_rewards = (
+            recon_trajectories.log_rewards
+        )  # Episodic reward is assumed.
+        trajectories_logprobs = None  # TODO
+
+        len_recon = recon_trajectories.states.tensor.shape[0]
+        for i in range(bs):
+            # Backward part
+            back_source_idx = backward_trajectories.when_is_done[i]
+            n_back = back_source_idx - K[i]
+            trajectories_states_tsr[:n_back, i] = backward_trajectories.states.tensor[
+                back_source_idx - n_back + 1 : back_source_idx + 1, i
+            ].flip(0)
+
+            # FIXME: This is not correct in general, because the backward action
+            # indices may not be consistent with the forward action indices.
+            trajectories_actions_tsr[:n_back, i] = backward_trajectories.actions.tensor[
+                back_source_idx - n_back : back_source_idx, i
+            ].flip(0)
+
+            # Forward part
+            trajectories_states_tsr[n_back : n_back + len_recon, i] = (
+                recon_trajectories.states.tensor[:, i]
+            )
+            trajectories_actions_tsr[n_back : n_back + len_recon - 1, i] = (
+                recon_trajectories.actions.tensor[: len_recon - 1, i]
+            )
+            if save_logprobs:  # concatenate log_probs
+                raise NotImplementedError("metropolis-hastings is not implemented yet.")
+
+        new_trajectories = Trajectories(
+            env=env,
+            states=env.States(trajectories_states_tsr),
+            conditioning=conditioning,
+            actions=env.Actions(trajectories_actions_tsr),
+            when_is_done=trajectories_dones,
+            is_backward=False,
+            # FIXME: This is awkward, since the trajectory contains
+            # both backward and forward parts.
+            # Maybe calculate log_pfs for the backward part and set is_backward=True?
+            log_rewards=trajectories_log_rewards,
+            log_probs=trajectories_logprobs,  # TODO: Support log_probs (`None` for now)
+        )
+
+        return new_trajectories
+
+    def sample_trajectories(
+        self,
+        env: Env,
+        n: Optional[int] = None,
+        states: Optional[States] = None,
+        conditioning: Optional[torch.Tensor] = None,
+        save_estimator_outputs: bool = False,  # FIXME: currently does not work when this is True
+        save_logprobs: bool = True,  # TODO: Support save_logprobs=True
+        n_local_search_loops: int = 0,
+        back_steps: torch.Tensor | None = None,
+        back_ratio: float | None = None,
+        use_metropolis_hastings: bool = False,
+        **policy_kwargs: Any,
+    ) -> Trajectories:
+        """Sample trajectories sequentially with optional local search.
+
+        Args:
+            env: The environment to sample trajectories from.
+            n: If given, a batch of n trajectories will be sampled, all
+                starting from the environment's $s_0$.
+            states: If given, trajectories will start from these states. Otherwise,
+                trajectories are sampled from $s_0$ and n must be provided.
+            conditioning: An optional tensor of conditioning information.
+            save_estimator_outputs: If True, the estimator outputs will be returned. This
+                is useful for off-policy training with a tempered policy.
+            save_logprobs: If True, calculates and saves the log probabilities of sampled
+                actions. This is useful for on-policy training.
+            n_local_search_loops: The number of local search iterations to apply.
+            back_steps: The number of backward steps.
+            back_ratio: The ratio of the number of backward steps to the length of the trajectory.
+            use_metropolis_hastings: If True, applies the Metropolis-Hastings acceptance criterion.
+            policy_kwargs: keyword arguments to be passed to the
+                `to_probability_distribution` method of the estimator. For example, for
+                DiscretePolicyEstimators, the kwargs can contain the `temperature`
+                parameter, `epsilon`, and `sf_bias`. In the continuous case these
+                kwargs will be user defined. This can be used to, for example, sample
+                off-policy.
+
+        Returns: A Trajectories object representing the batch of sampled trajectories,
+            where the batch size is n * (1 + n_local_search_loops).
+        """
+
+        trajectories = super().sample_trajectories(
+            env,
+            n,
+            states,
+            conditioning,
+            save_estimator_outputs,
+            save_logprobs or use_metropolis_hastings,
+            **policy_kwargs,
+        )
+
+        if n is None:
+            n = trajectories.n_trajectories
+
+        search_indices = torch.arange(n, device=trajectories.states.device)
+        for it in range(n_local_search_loops):
+            # Search phase
+            ls_trajectories = self.local_search(
+                env,
+                trajectories[search_indices],
+                conditioning,
+                save_estimator_outputs,
+                save_logprobs or use_metropolis_hastings,
+                back_steps,
+                back_ratio,
+                **policy_kwargs,
+            )
+            trajectories.extend(
+                ls_trajectories
+            )  # Store all trajectories regardless of the acceptance.
+
+            # Selection phase
+            if not use_metropolis_hastings:
+                last_indices = torch.arange(
+                    n * (it + 1), n * (it + 2), device=trajectories.states.device
+                )
+                prev_log_rewards = trajectories.log_rewards[search_indices]
+                new_log_rewards = ls_trajectories.log_rewards
+                update_indices = prev_log_rewards <= new_log_rewards
+                search_indices[update_indices] = last_indices[update_indices]
+            else:  # Metropolis-Hastings acceptance criterion
+                # TODO: Implement the Metropolis-Hastings acceptance criterion.
+                # We need p(x -> s -> x') = p_B(x -> s) * p_F(s -> x')
+                # and p(x' -> s' -> x) = p_B(x' -> s') * p_F(s' -> x)
+                # to calculate the acceptance ratio.
+                raise NotImplementedError(
+                    "Metropolis-Hastings acceptance criterion is not implemented."
+                )
+
+        return trajectories
diff --git a/tutorials/examples/train_hypergrid_simple.py b/tutorials/examples/train_hypergrid_simple.py
index e2058ade..5536aac3 100644
--- a/tutorials/examples/train_hypergrid_simple.py
+++ b/tutorials/examples/train_hypergrid_simple.py
@@ -41,8 +41,7 @@ def main(args):
     sampler = Sampler(estimator=pf_estimator)
 
     # Move the gflownet to the GPU.
-    if torch.cuda.is_available():
-        gflownet = gflownet.to("cuda")
+    gflownet = gflownet.to(device_str)
 
     # Policy parameters have their own LR. Log Z gets dedicated learning rate
     # (typically higher).
diff --git a/tutorials/examples/train_hypergrid_simple_ls.py b/tutorials/examples/train_hypergrid_simple_ls.py
new file mode 100644
index 00000000..753b39ca
--- /dev/null
+++ b/tutorials/examples/train_hypergrid_simple_ls.py
@@ -0,0 +1,116 @@
+#!/usr/bin/env python
+import argparse
+
+import torch
+from tqdm import tqdm
+
+from gfn.gflownet import TBGFlowNet
+from gfn.gym import HyperGrid
+from gfn.modules import DiscretePolicyEstimator
+from gfn.samplers import LocalSearchSampler
+from gfn.utils.common import set_seed
+from gfn.utils.modules import MLP
+
+
+def main(args):
+    set_seed(args.seed)
+    device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
+
+    # Setup the Environment.
+    env = HyperGrid(ndim=args.ndim, height=args.height, device_str=device_str)
+
+    # Build the GFlowNet.
+    module_PF = MLP(
+        input_dim=env.preprocessor.output_dim,
+        output_dim=env.n_actions,
+    )
+    module_PB = MLP(
+        input_dim=env.preprocessor.output_dim,
+        output_dim=env.n_actions - 1,
+        trunk=module_PF.trunk,
+    )
+    pf_estimator = DiscretePolicyEstimator(
+        module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor
+    )
+    pb_estimator = DiscretePolicyEstimator(
+        module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor
+    )
+    gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, logZ=0.0)
+
+    # Feed both pf and pb to the local search sampler.
+    sampler = LocalSearchSampler(estimator=pf_estimator, pb_estimator=pb_estimator)
+
+    # Move the gflownet to the GPU.
+    gflownet = gflownet.to(device_str)
+
+    # Policy parameters have their own LR. Log Z gets dedicated learning rate
+    # (typically higher).
+    optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=args.lr)
+    optimizer.add_param_group(
+        {"params": gflownet.logz_parameters(), "lr": args.lr_logz}
+    )
+
+    for i in (pbar := tqdm(range(args.n_iterations))):
+        trajectories = sampler.sample_trajectories(
+            env,
+            n=(args.batch_size // args.n_local_search_loops),
+            save_logprobs=False,
+            save_estimator_outputs=False,
+            epsilon=args.epsilon,
+            n_local_search_loops=args.n_local_search_loops,
+            back_ratio=args.back_ratio,
+        )
+        optimizer.zero_grad()
+        loss = gflownet.loss(env, trajectories)
+        loss.backward()
+        optimizer.step()
+        pbar.set_postfix({"loss": loss.item()})
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage")
+    parser.add_argument(
+        "--ndim", type=int, default=2, help="Number of dimensions in the environment"
+    )
+    parser.add_argument(
+        "--height", type=int, default=16, help="Height of the environment"
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1e-3,
+        help="Learning rate for the estimators' modules",
+    )
+    parser.add_argument(
+        "--lr_logz",
+        type=float,
+        default=1e-1,
+        help="Learning rate for the logZ parameter",
+    )
+    parser.add_argument(
+        "--n_iterations", type=int, default=1000, help="Number of iterations"
+    )
+    parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
+    parser.add_argument(
+        "--epsilon", type=float, default=0.1, help="Epsilon for the sampler"
+    )
+
+    # Local search parameters.
+    parser.add_argument(
+        "--n_local_search_loops",
+        type=int,
+        default=4,
+        help="Number of local search loops",
+    )
+    parser.add_argument(
+        "--back_ratio",
+        type=float,
+        default=0.5,
+        help="The ratio of the number of backward steps to the length of the trajectory",
+    )
+
+    args = parser.parse_args()
+
+    main(args)
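
For reference, the selection phase in `LocalSearchSampler.sample_trajectories` (when `use_metropolis_hastings=False`) is a greedy hill-climbing rule: a local-search proposal replaces a chain's incumbent trajectory whenever its log reward is at least as high, so `search_indices` always points at the best trajectory found so far in the growing batch. Below is a minimal, self-contained PyTorch sketch of just that bookkeeping; it is an illustration only, with hypothetical names and random numbers standing in for trajectory log rewards, not part of the gfn API.

import torch

# Toy illustration of the greedy selection phase: each of the n chains keeps an
# index into a growing pool of trajectories, and that index is advanced whenever
# a local-search proposal is at least as good as the current incumbent.
n, n_loops = 4, 3
torch.manual_seed(0)

pool_log_rewards = [torch.randn(n)]  # log rewards of the initial batch
search_indices = torch.arange(n)  # incumbent trajectory index per chain

for it in range(n_loops):
    proposal_log_rewards = torch.randn(n)  # stands in for ls_trajectories.log_rewards
    pool_log_rewards.append(proposal_log_rewards)

    all_log_rewards = torch.cat(pool_log_rewards)  # the pool grows by n per loop
    last_indices = torch.arange(n * (it + 1), n * (it + 2))
    accept = all_log_rewards[search_indices] <= proposal_log_rewards
    search_indices[accept] = last_indices[accept]  # greedy hill-climbing update

print(torch.cat(pool_log_rewards)[search_indices])  # running-best log reward per chain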