Skip to content

Commit

Permalink
Merge pull request #203 from GFNOrg/hyeok9855/minor-refactorings
Browse files Browse the repository at this point in the history
Minor refactorings and Rename NeuralNet to MLP
  • Loading branch information
hyeok9855 authored Oct 29, 2024
2 parents b3bae95 + 5a4198e commit 57cc269
Show file tree
Hide file tree
Showing 23 changed files with 576 additions and 586 deletions.
22 changes: 11 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,19 @@ from gfn.gflownet import TBGFlowNet
from gfn.gym import HyperGrid # We use the hyper grid environment
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import Sampler
from gfn.utils import NeuralNet # NeuralNet is a simple multi-layer perceptron (MLP)
from gfn.utils.modules import MLP # is a simple multi-layer perceptron (MLP)

# 1 - We define the environment.
env = HyperGrid(ndim=4, height=8, R0=0.01) # Grid of size 8x8x8x8

# 2 - We define the needed modules (neural networks).
# The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator
module_PF = NeuralNet(
module_PF = MLP(
input_dim=env.preprocessor.output_dim,
output_dim=env.n_actions
) # Neural network for the forward policy, with as many outputs as there are actions

module_PB = NeuralNet(
module_PB = MLP(
input_dim=env.preprocessor.output_dim,
output_dim=env.n_actions - 1,
trunk=module_PF.trunk # We share all the parameters of P_F and P_B, except for the last layer
Expand All @@ -102,7 +102,7 @@ optimizer.add_param_group({"params": gfn.logz_parameters(), "lr": 1e-1})

# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
for i in (pbar := tqdm(range(1000))):
trajectories = sampler.sample_trajectories(env=env, n_trajectories=16)
trajectories = sampler.sample_trajectories(env=env, n=16)
optimizer.zero_grad()
loss = gfn.loss(env, trajectories)
loss.backward()
Expand All @@ -121,24 +121,24 @@ from gfn.gflownet import SubTBGFlowNet
from gfn.gym import HyperGrid # We use the hyper grid environment
from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
from gfn.samplers import Sampler
from gfn.utils import NeuralNet # NeuralNet is a simple multi-layer perceptron (MLP)
from gfn.utils.modules import MLP # MLP is a simple multi-layer perceptron (MLP)

# 1 - We define the environment.
env = HyperGrid(ndim=4, height=8, R0=0.01) # Grid of size 8x8x8x8

# 2 - We define the needed modules (neural networks).
# The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator
module_PF = NeuralNet(
module_PF = MLP(
input_dim=env.preprocessor.output_dim,
output_dim=env.n_actions
) # Neural network for the forward policy, with as many outputs as there are actions

module_PB = NeuralNet(
module_PB = MLP(
input_dim=env.preprocessor.output_dim,
output_dim=env.n_actions - 1,
trunk=module_PF.trunk # We share all the parameters of P_F and P_B, except for the last layer
)
module_logF = NeuralNet(
module_logF = MLP(
input_dim=env.preprocessor.output_dim,
output_dim=1, # Important for ScalarEstimators!
)
Expand All @@ -161,7 +161,7 @@ optimizer.add_param_group({"params": gfn.logF_parameters(), "lr": 1e-2})

# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
for i in (pbar := tqdm(range(1000))):
trajectories = sampler.sample_trajectories(env=env, n_trajectories=16)
trajectories = sampler.sample_trajectories(env=env, n=16)
optimizer.zero_grad()
loss = gfn.loss(env, trajectories)
loss.backward()
Expand Down Expand Up @@ -238,12 +238,12 @@ In most cases, one needs to sample complete trajectories. From a batch of trajec

Training GFlowNets requires one or multiple estimators, called `GFNModule`s, which is an abstract subclass of `torch.nn.Module`. In addition to the usual `forward` function, `GFNModule`s need to implement a `required_output_dim` attribute, to ensure that the outputs have the required dimension for the task at hand; and some (but not all) need to implement a `to_probability_distribution` function.

- `DiscretePolicyEstimator` is a `GFNModule` that defines the policies $P_F(. \mid s)$ and $P_B(. \mid s)$ for discrete environments. When `is_backward=False`, the required output dimension is `n = env.n_actions`, and when `is_backward=True`, it is `n = env.n_actions - 1`. These `n` numbers represent the logits of a Categorical distribution. The corresponding `to_probability_distribution` function transforms the logits by masking illegal actions (according to the forward or backward masks), then return a Categorical distribution. The masking is done by setting the corresponding logit to $-\infty$. The function also includes exploration parameters, in order to define a tempered version of $P_F$, or a mixture of $P_F$ with a uniform distribution. `DiscretePolicyEstimator`` with `is_backward=False`` can be used to represent log-edge-flow estimators $\log F(s \rightarrow s')$.
- `DiscretePolicyEstimator` is a `GFNModule` that defines the policies $P_F(. \mid s)$ and $P_B(. \mid s)$ for discrete environments. When `is_backward=False`, the required output dimension is `n = env.n_actions`, and when `is_backward=True`, it is `n = env.n_actions - 1`. These `n` numbers represent the logits of a Categorical distribution. The corresponding `to_probability_distribution` function transforms the logits by masking illegal actions (according to the forward or backward masks), then return a Categorical distribution. The masking is done by setting the corresponding logit to $-\infty$. The function also includes exploration parameters, in order to define a tempered version of $P_F$, or a mixture of $P_F$ with a uniform distribution. `DiscretePolicyEstimator` with `is_backward=False` can be used to represent log-edge-flow estimators $\log F(s \rightarrow s')$.
- `ScalarModule` is a simple module with required output dimension 1. It is useful to define log-state flows $\log F(s)$.

For non-discrete environments, the user needs to specify their own policies $P_F$ and $P_B$. The module, taking as input a batch of states (as a `States`) object, should return the batched parameters of a `torch.Distribution`. The distribution depends on the environment. The `to_probability_distribution` function handles the conversion of the parameter outputs to an actual batched `Distribution` object, that implements at least the `sample` and `log_prob` functions. An example is provided [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/gym/helpers/box_utils.py), for a square environment in which the forward policy has support either on a quarter disk, or on an arc-circle, such that the angle, and the radius (for the quarter disk part) are scaled samples from a mixture of Beta distributions. The provided example shows an intricate scenario, and it is not expected that user defined environment need this much level of details.

In general, (and perhaps obviously) the `to_probability_distribution` method is used to calculate a probability distribution from a policy. Therefore, in order to go off-policy, one needs to modify the computations in this method during sampling. One accomplishes this using `policy_kwargs`, a `dict` of kwarg-value pairs which are used by the `Estimator` when calculating the new policy. In the discrete case, where common settings apply, one can see their use in `DiscretePolicyEstimator`'s `to_probability_distribution` method by passing a softmax `temperature`, `sf_bias` (a scalar to subtract from the exit action logit) or `epsilon` which allows for e-greedy style exploration. In the continuous case, it is not possible to forsee the methods used for off-policy exploration (as it depends on the details of the `to_probability_distribution` method, which is not generic for continuous GFNs), so this must be handled by the user, using custom `policy_kwargs`.
In general, (and perhaps obviously) the `to_probability_distribution` method is used to calculate a probability distribution from a policy. Therefore, in order to go off-policy, one needs to modify the computations in this method during sampling. One accomplishes this using `policy_kwargs`, a `dict` of kwarg-value pairs which are used by the `Estimator` when calculating the new policy. In the discrete case, where common settings apply, one can see their use in `DiscretePolicyEstimator`'s `to_probability_distribution` method by passing a softmax `temperature`, `sf_bias` (a scalar to subtract from the exit action logit) or `epsilon` which allows for e-greedy style exploration. In the continuous case, it is not possible to foresee the methods used for off-policy exploration (as it depends on the details of the `to_probability_distribution` method, which is not generic for continuous GFNs), so this must be handled by the user, using custom `policy_kwargs`.

In all `GFNModule`s, note that the input of the `forward` function is a `States` object. Meaning that they first need to be transformed to tensors. However, `states.tensor` does not necessarily include the structure that a neural network can used to generalize. It is common in these scenarios to have a function that transforms these raw tensor states to ones where the structure is clearer, via a `Preprocessor` object, that is part of the environment. More on this [here](https://github.com/saleml/torchgfn/tree/master/tutorials/ENV.md). The default preprocessor of an environment is the identity preprocessor. The `forward` pass thus first calls the `preprocessor` attribute of the environment on `States`, before performing any transformation. The `preprocessor` is thus an attribute of the module. If it is not explicitly defined, it is set to the identity preprocessor.

Expand Down
2 changes: 1 addition & 1 deletion src/gfn/containers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from .replay_buffer import PrioritizedReplayBuffer, ReplayBuffer
from .trajectories import Trajectories
from .transitions import Transitions
10 changes: 4 additions & 6 deletions src/gfn/containers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class ReplayBuffer:
"""A replay buffer of trajectories or transitions.
"""A replay buffer of trajectories, transitions, or states.
Attributes:
env: the Environment instance.
Expand All @@ -40,17 +40,15 @@ def __init__(
self.env = env
self.capacity = capacity
self.terminating_states = None
self.objects_type = objects_type
if objects_type == "trajectories":
self.training_objects = Trajectories(env)
self.objects_type = "trajectories"
elif objects_type == "transitions":
self.training_objects = Transitions(env)
self.objects_type = "transitions"
elif objects_type == "states":
self.training_objects = env.states_from_batch_shape((0,))
self.terminating_states = env.states_from_batch_shape((0,))
self.terminating_states.log_rewards = torch.zeros((0,), device=env.device)
self.objects_type = "states"
else:
raise ValueError(f"Unknown objects_type: {objects_type}")

Expand Down Expand Up @@ -146,7 +144,7 @@ def __init__(
def _add_objs(
self,
training_objects: Transitions | Trajectories | tuple[States],
terminating_states: States | None = None
terminating_states: States | None = None,
):
"""Adds a training object to the buffer."""
# Adds the objects to the buffer.
Expand Down Expand Up @@ -187,7 +185,7 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):
training_objects = training_objects[ix]

# Filter all batch logrewards lower than the smallest logreward in buffer.
min_reward_in_buffer = self.training_objects.log_rewards.min()
min_reward_in_buffer = self.training_objects.log_rewards.min() # type: ignore # FIXME
idx_bigger_rewards = training_objects.log_rewards >= min_reward_in_buffer
training_objects = training_objects[idx_bigger_rewards]

Expand Down
56 changes: 28 additions & 28 deletions src/gfn/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,18 @@ def __init__(

def states_from_tensor(self, tensor: torch.Tensor):
"""Wraps the supplied Tensor in a States instance.
Args:
tensor: The tensor of shape "state_shape" representing the states.
Returns:
States: An instance of States.
"""
return self.States(tensor)

def states_from_batch_shape(self, batch_shape: Tuple):
"""Returns a batch of s0 states with a given batch_shape.
Args:
batch_shape: Tuple representing the shape of the batch of states.
Expand All @@ -98,38 +98,36 @@ def states_from_batch_shape(self, batch_shape: Tuple):

def actions_from_tensor(self, tensor: torch.Tensor):
"""Wraps the supplied Tensor an an Actions instance.
Args:
tensor: The tensor of shape "action_shape" representing the actions.
Returns:
Actions: An instance of Actions.
"""
return self.Actions(tensor)

def actions_from_batch_shape(self, batch_shape: Tuple):
"""Returns a batch of dummy actions with the supplied batch_shape.
Args:
batch_shape: Tuple representing the shape of the batch of actions.
Returns:
Actions: A batch of dummy actions.
"""
return self.Actions.make_dummy_actions(batch_shape)

# To be implemented by the User.
@abstractmethod
def step(
self, states: States, actions: Actions
) -> torch.Tensor:
def step(self, states: States, actions: Actions) -> torch.Tensor:
"""Function that takes a batch of states and actions and returns a batch of next
states. Does not need to check whether the actions are valid or the states are sink states.
Args:
states: A batch of states.
actions: A batch of actions.
Returns:
torch.Tensor: A batch of next states.
"""
Expand All @@ -140,11 +138,11 @@ def backward_step( # TODO: rename to backward_step, other method becomes _backw
) -> torch.Tensor:
"""Function that takes a batch of states and actions and returns a batch of previous
states. Does not need to check whether the actions are valid or the states are sink states.
Args:
states: A batch of states.
actions: A batch of actions.
Returns:
torch.Tensor: A batch of previous states.
"""
Expand Down Expand Up @@ -312,7 +310,7 @@ def reward(self, final_states: States) -> torch.Tensor:
Args:
final_states: A batch of final states.
Returns:
torch.Tensor: Tensor of shape "batch_shape" containing the rewards.
"""
Expand All @@ -321,10 +319,10 @@ def reward(self, final_states: States) -> torch.Tensor:
def log_reward(self, final_states: States) -> torch.Tensor:
"""Calculates the log reward.
This or reward must be implemented.
Args:
final_states: A batch of final states.
Returns:
torch.Tensor: Tensor of shape "batch_shape" containing the log rewards.
"""
Expand All @@ -337,6 +335,13 @@ def log_partition(self) -> float:
"The environment does not support enumeration of states"
)

@property
def true_dist_pmf(self) -> torch.Tensor:
"Returns a one-dimensional tensor representing the true distribution."
raise NotImplementedError(
"The environment does not support enumeration of states"
)


class DiscreteEnv(Env, ABC):
"""
Expand Down Expand Up @@ -386,7 +391,6 @@ def __init__(
assert dummy_action.shape == action_shape
assert exit_action.shape == action_shape


self.n_actions = n_actions # Before init, for compatibility with States.
super().__init__(
s0,
Expand All @@ -403,10 +407,10 @@ def __init__(

def states_from_tensor(self, tensor: torch.Tensor):
"""Wraps the supplied Tensor in a States instance & updates masks.
Args:
tensor: The tensor of shape "state_shape" representing the states.
Returns:
States: An instance of States.
"""
Expand Down Expand Up @@ -489,29 +493,25 @@ def _step(self, states: DiscreteStates, actions: Actions) -> States:
) # TODO: update_masks is owned by the env, not the states!!
return new_states

def get_states_indices(
self, states: DiscreteStates
) -> torch.Tensor:
def get_states_indices(self, states: DiscreteStates) -> torch.Tensor:
"""Returns the indices of the states in the environment.
Args:
states: The batch of states.
Returns:
torch.Tensor: Tensor of shape "batch_shape" containing the indices of the states.
"""
return NotImplementedError(
"The environment does not support enumeration of states"
)

def get_terminating_states_indices(
self, states: DiscreteStates
) -> torch.Tensor:
def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor:
"""Returns the indices of the terminating states in the environment.
Args:
states: The batch of states.
Returns:
torch.Tensor: Tensor of shape "batch_shape" containing the indices of the terminating states.
"""
Expand Down
Loading

0 comments on commit 57cc269

Please sign in to comment.