Merge pull request #147 from GFNOrg/rethinking_sampling
Rethinking sampling
josephdviviano authored Feb 16, 2024
2 parents 68fda28 + e7c7453 commit eedc7e8
Showing 29 changed files with 1,816 additions and 270 deletions.
20 changes: 8 additions & 12 deletions README.md
@@ -71,12 +71,10 @@ from gfn.utils import NeuralNet # NeuralNet is a simple multi-layer perceptron

if __name__ == "__main__":

# 1 - We define the environment

env = HyperGrid(ndim=4, height=8, R0=0.01) # Grid of size 8x8x8x8

# 2 - We define the needed modules (neural networks)
# 1 - We define the environment.
env = HyperGrid(ndim=4, height=8, R0=0.01) # Grid of size 8x8x8x8

# 2 - We define the needed modules (neural networks).
# The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator
module_PF = NeuralNet(
input_dim=env.preprocessor.output_dim,
@@ -88,17 +86,14 @@ if __name__ == "__main__":
torso=module_PF.torso # We share all the parameters of P_F and P_B, except for the last layer
)

# 3 - We define the estimators

# 3 - We define the estimators.
pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)

# 4 - We define the GFlowNet

# 4 - We define the GFlowNet.
gfn = TBGFlowNet(init_logZ=0., pf=pf_estimator, pb=pb_estimator) # We initialize logZ to 0

# 5 - We define the sampler and the optimizer

# 5 - We define the sampler and the optimizer.
sampler = Sampler(estimator=pf_estimator) # We use an on-policy sampler, based on the forward policy

# Policy parameters have their own LR.
@@ -110,7 +105,6 @@ if __name__ == "__main__":
optimizer.add_param_group({"params": logz_params, "lr": 1e-1})

# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration

for i in (pbar := tqdm(range(1000))):
trajectories = sampler.sample_trajectories(env=env, n_trajectories=16)
optimizer.zero_grad()
@@ -193,6 +187,8 @@ Training GFlowNets requires one or multiple estimators, called `GFNModule`s, whi

For non-discrete environments, the user needs to specify their own policies $P_F$ and $P_B$. The module, taking as input a batch of states (as a `States` object), should return the batched parameters of a `torch.Distribution`. The distribution depends on the environment. The `to_probability_distribution` function handles the conversion of the parameter outputs to an actual batched `Distribution` object, which implements at least the `sample` and `log_prob` functions. An example is provided [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/gym/helpers/box_utils.py), for a square environment in which the forward policy has support either on a quarter disk or on an arc-circle, such that the angle and the radius (for the quarter-disk part) are scaled samples from a mixture of Beta distributions. The provided example shows an intricate scenario, and user-defined environments are not expected to need this level of detail.
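
As an illustration, below is a minimal sketch (not the library's Box example) of what such a continuous policy module could look like. The class name `GaussianPolicyModule`, the `state_dim`/`action_dim` parameters, and the exact `to_probability_distribution` signature are assumptions made for this sketch; the actual `GFNModule` interface may differ.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal


class GaussianPolicyModule(nn.Module):
    """Hypothetical module returning a diagonal Gaussian over continuous actions."""

    def __init__(self, state_dim: int, action_dim: int, hidden: int = 64):
        super().__init__()
        self.action_dim = action_dim
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2 * action_dim),  # outputs mean and log-std
        )

    def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor:
        # Returns the batched parameters of the distribution, one row per state.
        return self.net(preprocessed_states)

    def to_probability_distribution(self, states, module_output: torch.Tensor) -> Normal:
        mean, log_std = module_output.split(self.action_dim, dim=-1)
        # Normal implements both `sample` and `log_prob`, as required.
        return Normal(loc=mean, scale=log_std.exp())
```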

In general (and perhaps obviously), the `to_probability_distribution` method is used to calculate a probability distribution from a policy. Therefore, in order to go off-policy, one needs to modify the computations in this method during sampling. One accomplishes this using `policy_kwargs`, a `dict` of kwarg-value pairs which are used by the `Estimator` when calculating the new policy. In the discrete case, where common settings apply, one can see their use in `DiscretePolicyEstimator`'s `to_probability_distribution` method by passing a softmax `temperature`, `sf_bias` (a scalar to subtract from the exit action logit), or `epsilon`, which allows for epsilon-greedy-style exploration. In the continuous case, it is not possible to foresee the methods used for off-policy exploration (as they depend on the details of the `to_probability_distribution` method, which is not generic for continuous GFNs), so this must be handled by the user via custom `policy_kwargs`.
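
Continuing the quickstart example above, a minimal sketch of discrete off-policy sampling could look as follows. It assumes the sampler forwards these keyword arguments to the estimator's `to_probability_distribution`; the exact plumbing (e.g., whether they are passed individually or as a `policy_kwargs` dict) and the chosen values are illustrative and may differ between versions.

```python
# Hypothetical off-policy settings; the sampler is assumed to forward them
# to DiscretePolicyEstimator.to_probability_distribution.
exploration_kwargs = {
    "temperature": 2.0,  # flatten the softmax over action logits
    "sf_bias": 1.0,      # subtract 1.0 from the exit-action logit
    "epsilon": 0.1,      # epsilon-greedy-style random exploration
}

trajectories = sampler.sample_trajectories(
    env=env,
    n_trajectories=16,
    **exploration_kwargs,
)
```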

In all `GFNModule`s, note that the input of the `forward` function is a `States` object, meaning that the states first need to be transformed to tensors. However, `states.tensor` does not necessarily include the structure that a neural network can use to generalize. It is common in these scenarios to have a function that transforms these raw tensor states into ones where the structure is clearer, via a `Preprocessor` object that is part of the environment. More on this [here](https://github.com/saleml/torchgfn/tree/master/tutorials/ENV.md). The default preprocessor of an environment is the identity preprocessor. The `forward` pass thus first calls the `preprocessor` attribute of the environment on the `States` object, before performing any transformation. The `preprocessor` is thus an attribute of the module; if it is not explicitly defined, it is set to the identity preprocessor.
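
For instance, a hedged sketch of a custom preprocessor for a discrete environment might look like the following. The `OneHotPreprocessor` name and `n_states` parameter are illustrative, and it assumes that `Preprocessor` subclasses expose an `output_dim` and implement a `preprocess` method taking a `States` object; check the environment tutorial linked above for the exact interface.

```python
import torch
import torch.nn.functional as F

from gfn.preprocessors import Preprocessor


class OneHotPreprocessor(Preprocessor):
    """Hypothetical preprocessor that one-hot encodes integer state indices."""

    def __init__(self, n_states: int):
        super().__init__(output_dim=n_states)  # assumed base-class signature
        self.n_states = n_states

    def preprocess(self, states) -> torch.Tensor:
        # states.tensor holds raw integer indices; one-hot encoding exposes a
        # structure that a downstream MLP can use to generalize.
        return F.one_hot(states.tensor.long().squeeze(-1), self.n_states).float()
```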

For discrete environments, a `Tabular` module is provided, where a lookup table is used instead of a neural network. Additionally, a `UniformPB` module is provided, implementing a uniform backward policy. These modules are provided [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/utils/modules.py).
111 changes: 104 additions & 7 deletions src/gfn/containers/trajectories.py
@@ -7,13 +7,20 @@
from gfn.env import Env
from gfn.states import States

import numpy as np
import torch
from torch import Tensor
from torchtyping import TensorType as TT

from gfn.containers.base import Container
from gfn.containers.transitions import Transitions


def is_tensor(t) -> bool:
"""Checks whether t is a torch.Tensor instance."""
return isinstance(t, Tensor)


# TODO: remove env from this class?
class Trajectories(Container):
"""Container for complete trajectories (starting in $s_0$ and ending in $s_f$).
@@ -47,16 +54,21 @@ def __init__(
is_backward: bool = False,
log_rewards: TT["n_trajectories", torch.float] | None = None,
log_probs: TT["max_length", "n_trajectories", torch.float] | None = None,
estimator_outputs: torch.Tensor | None = None,
) -> None:
"""
Args:
env: The environment in which the trajectories are defined.
states: The states of the trajectories. Defaults to None.
actions: The actions of the trajectories. Defaults to None.
when_is_done: The time step at which each trajectory ends. Defaults to None.
is_backward: Whether the trajectories are backward or forward. Defaults to False.
log_rewards: The log_rewards of the trajectories. Defaults to None.
log_probs: The log probabilities of the trajectories' actions. Defaults to None.
states: The states of the trajectories.
actions: The actions of the trajectories.
when_is_done: The time step at which each trajectory ends.
is_backward: Whether the trajectories are backward or forward.
log_rewards: The log_rewards of the trajectories.
log_probs: The log probabilities of the trajectories' actions.
estimator_outputs: When forward sampling off-policy for an n-step
trajectory, n forward passes will be made on some function approximator,
which may need to be re-used (for example, for evaluating PF). To avoid
duplicated effort, the outputs of the forward passes can be stored here.
If states is None, then the states are initialized to an empty States object,
that can be populated on the fly. If log_rewards is None, then `env.log_reward`
Expand Down Expand Up @@ -87,6 +99,7 @@ def __init__(
if log_probs is not None
else torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
)
self.estimator_outputs = estimator_outputs

def __repr__(self) -> str:
states = self.states.tensor.transpose(0, 1)
@@ -154,6 +167,21 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories:
log_rewards = (
self._log_rewards[index] if self._log_rewards is not None else None
)
if is_tensor(self.estimator_outputs):
            # TODO: Is there a safer way to index self.estimator_outputs for
            # n-dimensional estimator outputs?
            #
            # First we index along the trajectory dimension of the estimator
            # outputs (dim=1). This can be thought of as the instance dimension,
            # and is compatible with all supported indexing approaches.
            # All dims > 1 are not explicitly indexed unless the dimensionality
            # of `index` matches all dimensions of `estimator_outputs` aside
            # from the trajectory dimension.
estimator_outputs = self.estimator_outputs[:, index]
# Next we index along the trajectory length (dim=0)
estimator_outputs = estimator_outputs[:new_max_length]
else:
estimator_outputs = None

return Trajectories(
env=self.env,
@@ -163,6 +191,7 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories:
is_backward=self.is_backward,
log_rewards=log_rewards,
log_probs=log_probs,
estimator_outputs=estimator_outputs,
)

@staticmethod
@@ -198,7 +227,10 @@ def extend(self, other: Trajectories) -> None:
Args:
other: an external set of Trajectories.
"""
if len(other) == 0:
return

# TODO: The replay buffer is storing `dones` - this wastes a lot of space.
self.actions.extend(other.actions)
self.states.extend(other.states)
self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0)
@@ -213,11 +245,76 @@ def extend(self, other: Trajectories) -> None:

if self._log_rewards is not None and other._log_rewards is not None:
self._log_rewards = torch.cat(
(self._log_rewards, other._log_rewards), dim=0
(self._log_rewards, other._log_rewards),
dim=0,
)
else:
self._log_rewards = None

# Either set, or append, estimator outputs if they exist in the submitted
# trajectory.
if self.estimator_outputs is None and is_tensor(other.estimator_outputs):
self.estimator_outputs = other.estimator_outputs
elif is_tensor(self.estimator_outputs) and is_tensor(other.estimator_outputs):
batch_shape = self.actions.batch_shape
n_bs = len(batch_shape)
output_dtype = self.estimator_outputs.dtype

if n_bs == 1:
# Concatenate along the only batch dimension.
self.estimator_outputs = torch.cat(
(self.estimator_outputs, other.estimator_outputs),
dim=0,
)
elif n_bs == 2:
if self.estimator_outputs.shape[0] != other.estimator_outputs.shape[0]:
# First we need to pad the first dimension on either self or other.
self_shape = np.array(self.estimator_outputs.shape)
other_shape = np.array(other.estimator_outputs.shape)
required_first_dim = max(self_shape[0], other_shape[0])

# TODO: This should be a single reused function (#154)
# The size of self needs to grow to match other along dim=0.
if self_shape[0] < other_shape[0]:
pad_dim = required_first_dim - self_shape[0]
pad_dim_full = (pad_dim,) + tuple(self_shape[1:])
output_padding = torch.full(
pad_dim_full,
fill_value=-float("inf"),
dtype=self.estimator_outputs.dtype, # TODO: This isn't working! Hence the cast below...
device=self.estimator_outputs.device,
)
self.estimator_outputs = torch.cat(
(self.estimator_outputs, output_padding),
dim=0,
)

# The size of other needs to grow to match self along dim=0.
if other_shape[0] < self_shape[0]:
pad_dim = required_first_dim - other_shape[0]
pad_dim_full = (pad_dim,) + tuple(other_shape[1:])
output_padding = torch.full(
pad_dim_full,
fill_value=-float("inf"),
dtype=other.estimator_outputs.dtype, # TODO: This isn't working! Hence the cast below...
device=other.estimator_outputs.device,
)
other.estimator_outputs = torch.cat(
(other.estimator_outputs, output_padding),
dim=0,
)

# Concatenate the tensors along the second dimension.
self.estimator_outputs = torch.cat(
(self.estimator_outputs, other.estimator_outputs),
dim=1,
).to(
dtype=output_dtype
) # Cast to prevent single precision becoming double precision... weird.

# Sanity check. TODO: Remove?
assert self.estimator_outputs.shape[:n_bs] == batch_shape

def to_transitions(self) -> Transitions:
"""Returns a `Transitions` object from the trajectories."""
states = self.states[:-1][~self.actions.is_dummy]
26 changes: 9 additions & 17 deletions src/gfn/env.py
@@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Optional, Tuple, Union

import torch
@@ -8,6 +7,7 @@
from gfn.actions import Actions
from gfn.preprocessors import IdentityPreprocessor, Preprocessor
from gfn.states import DiscreteStates, States
from gfn.utils.common import set_seed

# Errors
NonValidActionsError = type("NonValidActionsError", (ValueError,), {})
@@ -23,7 +23,6 @@ def __init__(
sf: Optional[TT["state_shape", torch.float]] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
log_reward_clip: Optional[float] = -100.0,
):
"""Initializes an environment.
@@ -37,7 +36,6 @@
preprocessor: a Preprocessor object that converts raw states to a tensor
that can be fed into a neural network. Defaults to None, in which case
the IdentityPreprocessor is used.
log_reward_clip: Used to clip small rewards (in particular, log(0) rewards).
"""
self.device = torch.device(device_str) if device_str is not None else s0.device

@@ -58,7 +56,6 @@

self.preprocessor = preprocessor
self.is_discrete = False
self.log_reward_clip = log_reward_clip

@abstractmethod
def make_States_class(self) -> type[States]:
@@ -83,7 +80,7 @@ def reset(
assert not (random and sink)

if random and seed is not None:
torch.manual_seed(seed)
set_seed(seed, performance_mode=True)

if batch_shape is None:
batch_shape = (1,)
@@ -94,15 +91,15 @@
)

@abstractmethod
def maskless_step(
def maskless_step( # TODO: rename to step, other method becomes _step.
self, states: States, actions: Actions
) -> TT["batch_shape", "state_shape", torch.float]:
"""Function that takes a batch of states and actions and returns a batch of next
states. Does not need to check whether the actions are valid or the states are sink states.
"""

@abstractmethod
def maskless_backward_step(
def maskless_backward_step( # TODO: rename to backward_step, other method becomes _backward_step.
self, states: States, actions: Actions
) -> TT["batch_shape", "state_shape", torch.float]:
"""Function that takes a batch of states and actions and returns a batch of previous
@@ -134,7 +131,7 @@ def step(
) -> States:
"""Function that takes a batch of states and actions and returns a batch of next
states and a boolean tensor indicating sink states in the new batch."""
new_states = deepcopy(states)
new_states = states.clone() # TODO: Ensure this is efficient!
valid_states_idx: TT["batch_shape", torch.bool] = ~states.is_sink_state
valid_actions = actions[valid_states_idx]
valid_states = states[valid_states_idx]
@@ -154,8 +151,6 @@
new_not_done_states_tensor = self.maskless_step(
not_done_states, not_done_actions
)
# if isinstance(new_states, DiscreteStates):
# new_not_done_states.masks = self.update_masks(not_done_states, not_done_actions)

new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor

@@ -168,7 +163,7 @@ def backward_step(
) -> States:
"""Function that takes a batch of states and actions and returns a batch of next
states and a boolean tensor indicating initial states in the new batch."""
new_states = deepcopy(states)
new_states = states.clone() # TODO: Ensure this is efficient!
valid_states_idx: TT["batch_shape", torch.bool] = ~new_states.is_initial_state
valid_actions = actions[valid_states_idx]
valid_states = states[valid_states_idx]
@@ -197,8 +192,8 @@ def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
raise NotImplementedError("Reward function is not implemented.")

def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""Calculates the log reward (clipping small rewards)."""
return torch.log(self.reward(final_states)).clip(self.log_reward_clip)
"""Calculates the log reward."""
return torch.log(self.reward(final_states))

@property
def log_partition(self) -> float:
@@ -224,7 +219,6 @@ def __init__(
sf: Optional[TT["state_shape", torch.float]] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
log_reward_clip: Optional[float] = -100.0,
):
"""Initializes a discrete environment.
@@ -234,12 +228,10 @@ def __init__(
sf: The final state tensor (shared among all trajectories).
device_str: String representation of a torch.device.
preprocessor: An optional preprocessor for intermediate states.
log_reward_clip: Used to clip small rewards (in particular, log(0) rewards).
"""
self.n_actions = n_actions
super().__init__(s0, sf, device_str, preprocessor, log_reward_clip)
super().__init__(s0, sf, device_str, preprocessor)
self.is_discrete = True
self.log_reward_clip = log_reward_clip

def make_Actions_class(self) -> type[Actions]:
env = self