diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py
index 38f926d8..2b42eb2c 100644
--- a/src/gfn/containers/trajectories.py
+++ b/src/gfn/containers/trajectories.py
@@ -78,9 +78,7 @@ def __init__(
         )
         assert len(self.states.batch_shape) == 2
         self.actions = (
-            actions
-            if actions is not None
-            else env.actions_from_batch_shape((0, 0))
+            actions if actions is not None else env.actions_from_batch_shape((0, 0))
         )
         assert len(self.actions.batch_shape) == 2
         self.when_is_done = (
@@ -236,9 +234,13 @@ def extend(self, other: Trajectories) -> None:
 
         # Either set, or append, estimator outputs if they exist in the submitted
         # trajectory.
-        if self.estimator_outputs is None and isinstance(other.estimator_outputs, Tensor):
+        if self.estimator_outputs is None and isinstance(
+            other.estimator_outputs, Tensor
+        ):
             self.estimator_outputs = other.estimator_outputs
-        elif isinstance(self.estimator_outputs, Tensor) and isinstance(other.estimator_outputs, Tensor):
+        elif isinstance(self.estimator_outputs, Tensor) and isinstance(
+            other.estimator_outputs, Tensor
+        ):
             batch_shape = self.actions.batch_shape
             n_bs = len(batch_shape)
             output_dtype = self.estimator_outputs.dtype
diff --git a/src/gfn/containers/transitions.py b/src/gfn/containers/transitions.py
index baddfa34..4b15f05e 100644
--- a/src/gfn/containers/transitions.py
+++ b/src/gfn/containers/transitions.py
@@ -73,9 +73,7 @@ def __init__(
         assert len(self.states.batch_shape) == 1
 
         self.actions = (
-            actions
-            if actions is not None
-            else env.actions_from_batch_shape((0,))
+            actions if actions is not None else env.actions_from_batch_shape((0,))
         )
         self.is_done = (
             is_done
diff --git a/src/gfn/env.py b/src/gfn/env.py
index d8b681c8..7d79def5 100644
--- a/src/gfn/env.py
+++ b/src/gfn/env.py
@@ -2,8 +2,8 @@
 from typing import Optional, Tuple, Union
 
 import torch
-from torchtyping import TensorType as TT
 from torch import Tensor
+from torchtyping import TensorType as TT
 
 from gfn.actions import Actions
 from gfn.preprocessors import IdentityPreprocessor, Preprocessor
@@ -12,6 +12,7 @@
 # Errors
 NonValidActionsError = type("NonValidActionsError", (ValueError,), {})
 
+
 def get_device(device_str, default_device):
     return torch.device(device_str) if device_str is not None else default_device
 
@@ -130,6 +131,7 @@ def make_States_class(self) -> type[States]:
 
         class DefaultEnvState(States):
             """Defines a States class for this environment."""
+
             state_shape = env.state_shape
             s0 = env.s0
             sf = env.sf
@@ -215,9 +217,7 @@ def _step(
         not_done_states = new_states[~new_sink_states_idx]
         not_done_actions = actions[~new_sink_states_idx]
 
-        new_not_done_states_tensor = self.step(
-            not_done_states, not_done_actions
-        )
+        new_not_done_states_tensor = self.step(not_done_states, not_done_actions)
         # TODO: Why is this here? Should it be removed?
         # if isinstance(new_states, DiscreteStates):
         #     new_not_done_states.masks = self.update_masks(not_done_states, not_done_actions)
@@ -247,9 +247,7 @@ def _backward_step(
         )
 
         # Calculate the backward step, and update only the states which are not Done.
-        new_not_done_states_tensor = self.backward_step(
-            valid_states, valid_actions
-        )
+        new_not_done_states_tensor = self.backward_step(valid_states, valid_actions)
         new_states.tensor[valid_states_idx] = new_not_done_states_tensor
 
         if isinstance(new_states, DiscreteStates):
@@ -316,7 +314,7 @@ def __init__(
         if isinstance(dummy_action, type(None)):
            dummy_action = torch.tensor([-1], device=device)
 
-        # The default exit action index is the final element of the action space. 
+        # The default exit action index is the final element of the action space.
         if isinstance(exit_action, type(None)):
             exit_action = torch.tensor([n_actions - 1], device=device)
 
@@ -382,7 +380,6 @@ def make_States_class(self) -> type[States]:
         env = self
 
         class DiscreteEnvStates(DiscreteStates):
-
             state_shape = env.state_shape
             s0 = env.s0
             sf = env.sf
@@ -413,7 +410,9 @@ def is_action_valid(
     def _step(self, states: DiscreteStates, actions: Actions) -> States:
         """Calls the core self._step method of the parent class, and updates masks."""
         new_states = super()._step(states, actions)
-        self.update_masks(new_states)  # TODO: update_masks is owned by the env, not the states!!
+        self.update_masks(
+            new_states
+        )  # TODO: update_masks is owned by the env, not the states!!
         return new_states
 
     def get_states_indices(
@@ -470,4 +469,3 @@ def terminating_states(self) -> DiscreteStates:
         return NotImplementedError(
             "The environment does not support enumeration of states"
         )
-
diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py
index 0656ba64..e7d80921 100644
--- a/src/gfn/gflownet/base.py
+++ b/src/gfn/gflownet/base.py
@@ -1,6 +1,6 @@
+import math
 from abc import ABC, abstractmethod
 from typing import Generic, Tuple, TypeVar, Union
-import math
 
 import torch
 import torch.nn as nn
diff --git a/src/gfn/gym/box.py b/src/gfn/gym/box.py
index 7c070682..22ed18a7 100644
--- a/src/gfn/gym/box.py
+++ b/src/gfn/gym/box.py
@@ -25,8 +25,12 @@ def __init__(
         self.delta = delta
         self.epsilon = epsilon
         s0 = torch.tensor([0.0, 0.0], device=torch.device(device_str))
-        exit_action = torch.tensor([-float("inf"), -float("inf")], device=torch.device(device_str))
-        dummy_action = torch.tensor([float("inf"), float("inf")], device=torch.device(device_str))
+        exit_action = torch.tensor(
+            [-float("inf"), -float("inf")], device=torch.device(device_str)
+        )
+        dummy_action = torch.tensor(
+            [float("inf"), float("inf")], device=torch.device(device_str)
+        )
 
         self.R0 = R0
         self.R1 = R1
@@ -41,8 +45,8 @@ def __init__(
         )
 
     def make_random_states_tensor(
-        self, batch_shape: Tuple[int, ...] 
-    ) -> TT["batch_shape", 2, torch.float]: 
+        self, batch_shape: Tuple[int, ...]
+    ) -> TT["batch_shape", 2, torch.float]:
         return torch.rand(batch_shape + (2,), device=self.device)
 
     def step(
diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py
index 85495f95..644d6cbd 100644
--- a/src/gfn/gym/discrete_ebm.py
+++ b/src/gfn/gym/discrete_ebm.py
@@ -2,8 +2,8 @@
 from typing import Literal, Tuple
 
 import torch
-from torch import Tensor
 import torch.nn as nn
+from torch import Tensor
 from torchtyping import TensorType as TT
 
 from gfn.actions import Actions
@@ -89,7 +89,7 @@ def __init__(
 
         super().__init__(
             s0=s0,
-            state_shape=(self.ndim, ),
+            state_shape=(self.ndim,),
             # dummy_action=,
             # exit_action=,
             n_actions=n_actions,
diff --git a/src/gfn/states.py b/src/gfn/states.py
index 3fd209d4..883765b8 100644
--- a/src/gfn/states.py
+++ b/src/gfn/states.py
@@ -2,7 +2,7 @@
 from abc import ABC, abstractmethod
 from math import prod
-from typing import ClassVar, Optional, Sequence, cast, Callable
+from typing import Callable, ClassVar, Optional, Sequence, cast
 
 import torch
 from torchtyping import TensorType as TT
@@ -49,7 +49,11 @@ class States(ABC):
     sf: ClassVar[
         TT["state_shape", torch.float]
     ]  # Dummy state, used to pad a batch of states
-    make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw(NotImplementedError("The environment does not support initialization of random states."))
+    make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw(
+        NotImplementedError(
+            "The environment does not support initialization of random states."
+        )
+    )
 
     def __init__(self, tensor: TT["batch_shape", "state_shape"]):
         """Initalize the State container with a batch of states.
@@ -267,6 +271,7 @@ class DiscreteStates(States, ABC):
         forward_masks: A boolean tensor of allowable forward policy actions.
         backward_masks: A boolean tensor of allowable backward policy actions.
     """
+
     n_actions: ClassVar[int]
     device: ClassVar[torch.device]
 
@@ -276,7 +281,6 @@ def __init__(
         self,
         tensor: TT["batch_shape", "state_shape"],
         forward_masks: Optional[TT["batch_shape", "n_actions", torch.bool]] = None,
         backward_masks: Optional[TT["batch_shape", "n_actions - 1", torch.bool]] = None,
     ) -> None:
-
         """Initalize a DiscreteStates container with a batch of states and masks.
 
         Args:
             tensor: A batch of states.
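Aside: the `make_random_states_tensor` hunk in src/gfn/states.py above relies on the generator-throw idiom to raise an exception from a lambda, since `raise` is a statement and cannot appear in a lambda body. A minimal, self-contained Python sketch of the same idiom (the name `raise_unsupported` is illustrative, not part of the library):

    # An unstarted generator's .throw() immediately raises the given
    # exception in the caller's frame, so the lambda "raises" on call.
    raise_unsupported = lambda x: (_ for _ in ()).throw(
        NotImplementedError("random state initialization is not supported")
    )

    try:
        raise_unsupported(None)
    except NotImplementedError as err:
        print(err)  # -> random state initialization is not supported
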
diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py
index 18801016..192a5dcb 100644
--- a/tutorials/examples/test_scripts.py
+++ b/tutorials/examples/test_scripts.py
@@ -5,8 +5,8 @@
 
 from dataclasses import dataclass
 
-import pytest
 import numpy as np
+import pytest
 
 from .train_box import main as train_box_main
 from .train_discreteebm import main as train_discreteebm_main
diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py
index 0ea3e913..e9ecbeae 100644
--- a/tutorials/examples/train_box.py
+++ b/tutorials/examples/train_box.py
@@ -233,9 +233,7 @@ def main(args):  # noqa: C901
             print(f"current optimizer LR: {optimizer.param_groups[0]['lr']}")
 
         trajectories = gflownet.sample_trajectories(
-            env,
-            sample_off_policy=False,
-            n_samples=args.batch_size
+            env, sample_off_policy=False, n_samples=args.batch_size
         )
         training_samples = gflownet.to_training_samples(trajectories)
 
diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py
index 5fdb2591..3a441648 100644
--- a/tutorials/examples/train_discreteebm.py
+++ b/tutorials/examples/train_discreteebm.py
@@ -20,11 +20,9 @@
 from gfn.gflownet import FMGFlowNet
 from gfn.gym import DiscreteEBM
 from gfn.modules import DiscretePolicyEstimator
-from gfn.utils.common import validate
+from gfn.utils.common import set_seed, validate
 from gfn.utils.modules import NeuralNet, Tabular
 
-from gfn.utils.common import set_seed
-
 DEFAULT_SEED = 4444
 
 
@@ -72,9 +70,7 @@ def main(args):  # noqa: C901
     validation_info = {"l1_dist": float("inf")}
     for iteration in trange(n_iterations):
         trajectories = gflownet.sample_trajectories(
-            env,
-            off_policy=False,
-            n_samples=args.batch_size
+            env, off_policy=False, n_samples=args.batch_size
         )
         training_samples = gflownet.to_training_samples(trajectories)
 
diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py
index f8982727..517da98e 100644
--- a/tutorials/examples/train_hypergrid.py
+++ b/tutorials/examples/train_hypergrid.py
@@ -28,11 +28,9 @@
 )
 from gfn.gym import HyperGrid
 from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
-from gfn.utils.common import validate
+from gfn.utils.common import set_seed, validate
 from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular
 
-from gfn.utils.common import set_seed
-
 DEFAULT_SEED = 4444
 
 
@@ -225,7 +223,9 @@ def main(args):  # noqa: C901
     n_iterations = args.n_trajectories // args.batch_size
     validation_info = {"l1_dist": float("inf")}
     for iteration in trange(n_iterations):
-        trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size, sample_off_policy=off_policy_sampling)
+        trajectories = gflownet.sample_trajectories(
+            env, n_samples=args.batch_size, sample_off_policy=off_policy_sampling
+        )
         training_samples = gflownet.to_training_samples(trajectories)
         if replay_buffer is not None:
             with torch.no_grad():
diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py
index 744e9294..4e69c4ee 100644
--- a/tutorials/examples/train_line.py
+++ b/tutorials/examples/train_line.py
@@ -7,10 +7,10 @@
 from tqdm import trange
 
 from gfn.gflownet import TBGFlowNet  # TODO: Extend to SubTBGFlowNet
+from gfn.gym.line import Line
 from gfn.modules import GFNModule
 from gfn.states import States
 from gfn.utils import NeuralNet
-from gfn.gym.line import Line
 from gfn.utils.common import set_seed
 
 
@@ -113,7 +113,9 @@ def log_prob(self, sampled_actions):
         actions_to_eval[~exit_idx] = sampled_actions[~exit_idx]
 
         if sum(~exit_idx) > 0:
-            logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[~exit_idx].unsqueeze(-1)
+            logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[
+                ~exit_idx
+            ].unsqueeze(-1)
 
         return logprobs.squeeze(-1)
 
@@ -187,6 +189,7 @@ def to_probability_distribution(
             n_steps=self.n_steps_per_trajectory,
         )
 
+
 def train(
     gflownet,
     env,
@@ -220,7 +223,6 @@ def train(
     scale_schedule = np.linspace(exploration_var_starting_val, 0, n_iterations)
 
     for iteration in tbar:
-
         optimizer.zero_grad()
         # Off Policy Sampling.
         trajectories = gflownet.sample_trajectories(
@@ -259,7 +261,6 @@ def train(
 
 
 if __name__ == "__main__":
-
     environment = Line(
         mus=[2, 5],
         sigmas=[0.5, 0.5],
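For context on the call sites reformatted in the tutorial scripts, this is the training-loop shape they all share. A minimal sketch, assuming a `gflownet`, `env`, and `optimizer` already built as in tutorials/examples/train_hypergrid.py; note the off-policy flag name differs between scripts (`sample_off_policy` here, `off_policy` in train_discreteebm.py), and the `gflownet.loss(env, training_samples)` call is assumed from those same tutorials rather than shown in this patch:

    from tqdm import trange

    def train_loop(gflownet, env, optimizer, n_iterations, batch_size, off_policy_sampling=False):
        for _ in trange(n_iterations):
            # Sample a batch of trajectories, then convert them to the
            # loss-specific training samples (transitions, trajectories, ...).
            trajectories = gflownet.sample_trajectories(
                env, n_samples=batch_size, sample_off_policy=off_policy_sampling
            )
            training_samples = gflownet.to_training_samples(trajectories)

            # Standard PyTorch update on the GFlowNet loss.
            optimizer.zero_grad()
            loss = gflownet.loss(env, training_samples)
            loss.backward()
            optimizer.step()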