diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py
index 5b235650..d0a97574 100644
--- a/src/gfn/samplers.py
+++ b/src/gfn/samplers.py
@@ -26,13 +26,21 @@ def __init__(
         self.estimator = estimator
 
     def sample_actions(
-        self, env: Env, states: States, **policy_kwargs: Optional[dict],
+        self,
+        env: Env,
+        states: States,
+        save_estimator_outputs: bool = False,
+        calculate_logprobs: bool = True,
+        **policy_kwargs: Optional[dict],
     ) -> Tuple[Actions, TT["batch_shape", torch.float]]:
         """Samples actions from the given states.
 
         Args:
             env: The environment to sample actions from.
-            states (States): A batch of states.
+            states: A batch of states.
+            save_estimator_outputs: If True, the estimator outputs will be returned.
+            calculate_logprobs: If True, calculates the log probabilities of sampled
+                actions.
             policy_kwargs: keyword arguments to be passed to the
                 `to_probability_distribution` method of the estimator. For example, for
                 DiscretePolicyEstimators, the kwargs can contain the `temperature`
@@ -40,6 +48,12 @@ def sample_actions(
                 kwargs will be user defined. This can be used to, for example, sample
                 off-policy.
 
+        When sampling off-policy, set `save_estimator_outputs=True` and
+        `calculate_logprobs=False`. Log probabilities are instead calculated
+        during the computation of `PF` as part of the `GFlowNet` class, and the
+        estimator outputs are required for estimating the logprobs of these
+        off-policy actions.
+
         Returns:
             A tuple of tensors containing:
              - An Actions object containing the sampled actions.
@@ -47,24 +61,35 @@ def sample_actions(
                 the sampled actions under the probability distribution of the given
                 states.
         """
-        module_output = self.estimator(states)
+        estimator_output = self.estimator(states)
         dist = self.estimator.to_probability_distribution(
-            states, module_output, **policy_kwargs
+            states, estimator_output, **policy_kwargs
         )
 
         with torch.no_grad():
             actions = dist.sample()
-        log_probs = dist.log_prob(actions)
-        if torch.any(torch.isinf(log_probs)):
-            raise RuntimeError("Log probabilities are inf. This should not happen.")
 
-        return env.Actions(actions), log_probs  # TODO: return module_output here.
+        if calculate_logprobs:
+            log_probs = dist.log_prob(actions)
+            if torch.any(torch.isinf(log_probs)):
+                raise RuntimeError("Log probabilities are inf. This should not happen.")
+        else:
+            log_probs = None
+
+        actions = env.Actions(actions)
+
+        if not save_estimator_outputs:
+            estimator_output = None
+
+        return actions, log_probs, estimator_output
+
 
     def sample_trajectories(
         self,
         env: Env,
         states: Optional[States] = None,
         n_trajectories: Optional[int] = None,
+        off_policy: bool = False,
         **policy_kwargs,
     ) -> Trajectories:
         """Sample trajectories sequentially.
@@ -75,6 +100,8 @@ def sample_trajectories(
                 trajectories are sampled from $s_o$ and n_trajectories must be provided.
             n_trajectories: If given, a batch of n_trajectories will be sampled all
                 starting from the environment's s_0.
+            off_policy: If True, log probability calculation is skipped during
+                sampling, and the estimator outputs are saved for later use.
             policy_kwargs: keyword arguments to be passed to the
                 `to_probability_distribution` method of the estimator. For example, for
                 DiscretePolicyEstimators, the kwargs can contain the `temperature`
@@ -88,8 +115,6 @@ def sample_trajectories(
             AssertionError: When both states and n_trajectories are specified.
             AssertionError: When states are not linear.
         """
-        # TODO: Optionally accept module outputs (this will skip inference using the
-        # estimator).
         if states is None:
             assert (
                 n_trajectories is not None
@@ -122,19 +147,30 @@ def sample_trajectories(
         )
 
         step = 0
+        all_estimator_outputs = []
 
         while not all(dones):
             actions = env.Actions.make_dummy_actions(batch_shape=(n_trajectories,))
             log_probs = torch.full(
                 (n_trajectories,), fill_value=0, dtype=torch.float, device=device
             )
-            # TODO: Retrieve module outputs here, and stack them along the trajectory
-            # length.
-            # TODO: Optionally submit module outputs to skip re-estimation.
-            valid_actions, actions_log_probs = self.sample_actions(env, states[~dones], **policy_kwargs)
-            actions[~dones] = valid_actions
+            # This optionally allows you to retrieve the estimator_outputs collected
+            # during sampling. This is useful if, for example, you want to evaluate
+            # off-policy actions later without repeating calculations to obtain the
+            # env distribution parameters.
+            valid_actions, actions_log_probs, estimator_outputs = self.sample_actions(
+                env,
+                states[~dones],
+                save_estimator_outputs=off_policy,
+                calculate_logprobs=not off_policy,
+                **policy_kwargs,
+            )
+            if estimator_outputs is not None:
+                all_estimator_outputs.append(estimator_outputs)
 
-            log_probs[~dones] = actions_log_probs
+            actions[~dones] = valid_actions
+            if not off_policy:  # When off_policy, actions_log_probs is None.
+                log_probs[~dones] = actions_log_probs
             trajectories_actions += [actions]
             trajectories_logprobs += [log_probs]
 
@@ -165,6 +201,8 @@ def sample_trajectories(
 
             trajectories_states += [states.tensor]
 
+        if off_policy:
+            all_estimator_outputs = torch.stack(all_estimator_outputs, dim=0)
         trajectories_states = torch.stack(trajectories_states, dim=0)
         trajectories_states = env.States(tensor=trajectories_states)
         trajectories_actions = env.Actions.stack(trajectories_actions)
@@ -180,4 +218,4 @@ def sample_trajectories(
             log_probs=trajectories_logprobs,  # TODO: Optionally skip computation of logprobs.
         )
 
-        return trajectories
+        return trajectories, all_estimator_outputs
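Usage sketch for the revised `sample_actions` signature, assuming the sampler class defined in samplers.py is `Sampler` and that `estimator`, `env`, and a batch of `states` are constructed elsewhere; these names are illustrative and not part of the diff.

    # Hypothetical setup: `estimator` and `env` come from user code.
    sampler = Sampler(estimator=estimator)

    # sample_actions now returns a 3-tuple. With save_estimator_outputs=True and
    # calculate_logprobs=False (the off-policy configuration), log_probs is None
    # and the raw estimator output is kept for estimating log probs later.
    actions, log_probs, estimator_output = sampler.sample_actions(
        env,
        states,
        save_estimator_outputs=True,
        calculate_logprobs=False,
    )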
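Under the same assumptions, a sketch of the new `sample_trajectories` return signature; `temperature` is only an example policy_kwarg, applicable when the estimator is a DiscretePolicyEstimator.

    # Off-policy: log-prob calculation is skipped during sampling, and the
    # estimator outputs (stacked along the trajectory length) are returned
    # alongside the trajectories for later use.
    trajectories, estimator_outputs = sampler.sample_trajectories(
        env,
        n_trajectories=16,
        off_policy=True,
        temperature=2.0,
    )

    # On-policy (default): log probs are computed during sampling and nothing is
    # saved, so the second return value is an empty list.
    trajectories, _ = sampler.sample_trajectories(env, n_trajectories=16)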