estimator outputs can be saved
josephdviviano committed Nov 21, 2023
1 parent 450ebf0 commit a8b637e
Showing 1 changed file with 55 additions and 17 deletions.
src/gfn/samplers.py (72 changes: 55 additions & 17 deletions)
@@ -26,45 +26,70 @@ def __init__(
        self.estimator = estimator

    def sample_actions(
-        self, env: Env, states: States, **policy_kwargs: Optional[dict],
+        self,
+        env: Env,
+        states: States,
+        save_estimator_outputs: bool = False,
+        calculate_logprobs: bool = True,
+        **policy_kwargs: Optional[dict],
    ) -> Tuple[Actions, TT["batch_shape", torch.float]]:
"""Samples actions from the given states.
Args:
env: The environment to sample actions from.
states (States): A batch of states.
states: A batch of states.
save_estimator_outputs: If True, the estimator outputs will be returned.
calculate_logprobs: If True, calculates the log probabilities of sampled
actions.
policy_kwargs: keyword arguments to be passed to the
`to_probability_distribution` method of the estimator. For example, for
DiscretePolicyEstimators, the kwargs can contain the `temperature`
parameter, `epsilon`, and `sf_bias`. In the continuous case these
kwargs will be user defined. This can be used to, for example, sample
off-policy.
+            When sampling off-policy, set `save_estimator_outputs=True` and
+                `calculate_logprobs=False`. Log probabilities are instead calculated
+                during the computation of `PF` as part of the `GFlowNet` class, and
+                the estimator outputs are required for estimating the logprobs of
+                these off-policy actions.
        Returns:
-            A tuple of tensors containing:
-                - An Actions object containing the sampled actions.
-                - A tensor of shape (*batch_shape,) containing the log probabilities of
-                    the sampled actions under the probability distribution of the given
-                    states.
+            A tuple containing:
+                - An Actions object containing the sampled actions.
+                - A tensor of shape (*batch_shape,) containing the log probabilities of
+                    the sampled actions under the probability distribution of the given
+                    states, or None when `calculate_logprobs=False`.
+                - The raw estimator output for the given states when
+                    `save_estimator_outputs=True`, otherwise None.
        """
-        module_output = self.estimator(states)
+        estimator_output = self.estimator(states)
        dist = self.estimator.to_probability_distribution(
-            states, module_output, **policy_kwargs
+            states, estimator_output, **policy_kwargs
        )

        with torch.no_grad():
            actions = dist.sample()
-            log_probs = dist.log_prob(actions)
-        if torch.any(torch.isinf(log_probs)):
-            raise RuntimeError("Log probabilities are inf. This should not happen.")
-
-        return env.Actions(actions), log_probs  # TODO: return module_output here.
+        if calculate_logprobs:
+            log_probs = dist.log_prob(actions)
+            if torch.any(torch.isinf(log_probs)):
+                raise RuntimeError("Log probabilities are inf. This should not happen.")
+        else:
+            log_probs = None
+
+        actions = env.Actions(actions)
+
+        if not save_estimator_outputs:
+            estimator_output = None
+
+        return actions, log_probs, estimator_output


    def sample_trajectories(
        self,
        env: Env,
        states: Optional[States] = None,
        n_trajectories: Optional[int] = None,
+        off_policy: bool = False,
        **policy_kwargs,
    ) -> Trajectories:
        """Sample trajectories sequentially.
@@ -75,6 +100,8 @@ def sample_trajectories(
                trajectories are sampled from $s_0$ and n_trajectories must be provided.
            n_trajectories: If given, a batch of n_trajectories will be sampled all
                starting from the environment's s_0.
+            off_policy: If True, log-probability calculation is skipped during
+                sampling and the estimator outputs are saved for later use.
            policy_kwargs: keyword arguments to be passed to the
                `to_probability_distribution` method of the estimator. For example, for
                DiscretePolicyEstimators, the kwargs can contain the `temperature`
Expand All @@ -88,8 +115,6 @@ def sample_trajectories(
            AssertionError: When both states and n_trajectories are specified.
            AssertionError: When states are not linear.
        """
-        # TODO: Optionally accept module outputs (this will skip inference using the
-        # estimator).
        if states is None:
            assert (
                n_trajectories is not None
@@ -122,19 +147,30 @@ def sample_trajectories(
        )

        step = 0
+        all_estimator_outputs = []

        while not all(dones):
            actions = env.Actions.make_dummy_actions(batch_shape=(n_trajectories,))
            log_probs = torch.full(
                (n_trajectories,), fill_value=0, dtype=torch.float, device=device
            )
-            # TODO: Retrieve module outputs here, and stack them along the trajectory
-            # length.
-            # TODO: Optionally submit module outputs to skip re-estimation.
-            valid_actions, actions_log_probs = self.sample_actions(env, states[~dones], **policy_kwargs)
-            actions[~dones] = valid_actions
+            # This optionally allows you to retrieve the estimator_outputs collected
+            # during sampling. This is useful if, for example, you want to evaluate
+            # off-policy actions later without repeating the calculations that produce
+            # the distribution parameters.
+            valid_actions, actions_log_probs, estimator_outputs = self.sample_actions(
+                env,
+                states[~dones],
+                save_estimator_outputs=off_policy,
+                calculate_logprobs=not off_policy,
+                **policy_kwargs
+            )
+            if estimator_outputs is not None:
+                all_estimator_outputs.append(estimator_outputs)
+
-            log_probs[~dones] = actions_log_probs
+            actions[~dones] = valid_actions
+            if not off_policy:  # When off_policy, actions_log_probs are None.
+                log_probs[~dones] = actions_log_probs
            trajectories_actions += [actions]
            trajectories_logprobs += [log_probs]

@@ -165,6 +201,8 @@ def sample_trajectories(

            trajectories_states += [states.tensor]

+        if off_policy:
+            all_estimator_outputs = torch.stack(all_estimator_outputs, dim=0)
        trajectories_states = torch.stack(trajectories_states, dim=0)
        trajectories_states = env.States(tensor=trajectories_states)
        trajectories_actions = env.Actions.stack(trajectories_actions)
@@ -180,4 +218,4 @@ def sample_trajectories(
            log_probs=trajectories_logprobs,  # TODO: Optionally skip computation of logprobs.
        )

-        return trajectories
+        return trajectories, all_estimator_outputs
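
To make the new `sample_actions` flags concrete, here is a minimal usage sketch (not part of the commit). It assumes the sampler class defined in src/gfn/samplers.py is importable as `gfn.samplers.Sampler`, and that `env`, `estimator`, and a batch of `states` have already been constructed elsewhere; `temperature` is only a valid policy kwarg for a DiscretePolicyEstimator, as the docstring above notes.

# Hypothetical setup: `env`, `estimator`, and `states` are assumed to exist.
from gfn.samplers import Sampler

sampler = Sampler(estimator=estimator)

# On-policy (defaults): log-probs are computed, estimator outputs are dropped.
actions, log_probs, estimator_output = sampler.sample_actions(env, states)
assert estimator_output is None

# Off-policy: skip log-prob computation and keep the estimator outputs, so the
# log-probs of these actions can be recovered later (e.g. while computing PF
# inside the GFlowNet class, as described in the docstring).
actions, log_probs, estimator_output = sampler.sample_actions(
    env,
    states,
    save_estimator_outputs=True,
    calculate_logprobs=False,
    temperature=2.0,  # example policy kwarg for a DiscretePolicyEstimator
)
assert log_probs is None and estimator_output is not None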

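The same assumptions carry over to trajectory sampling: with `off_policy=True`, the per-step log-probability computation is skipped and the stacked estimator outputs are returned alongside the trajectories, while on-policy callers receive an empty list in that slot. A sketch, reusing the hypothetical `sampler` and `env` from above:

# Off-policy trajectory sampling (sketch, not part of the commit).
trajectories, estimator_outputs = sampler.sample_trajectories(
    env,
    n_trajectories=16,
    off_policy=True,
    epsilon=0.1,  # example policy kwarg for a DiscretePolicyEstimator
)
# With off_policy=True the per-step log-probs stay at their dummy value of 0;
# `estimator_outputs` holds the stacked per-step estimator outputs that a
# GFlowNet can use later to evaluate the log-probs of these actions.
# Note: torch.stack requires the per-step outputs to share a shape, so this
# path assumes equally sized active batches at every step.

# sample_trajectories now always returns a pair, so existing on-policy call
# sites need to unpack (and can discard) the second value.
trajectories, _ = sampler.sample_trajectories(env, n_trajectories=16)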