
Commit cdffab1
black / isort
josephdviviano committed Nov 29, 2023
1 parent f12cbec commit cdffab1
Showing 12 changed files with 48 additions and 47 deletions.
12 changes: 7 additions & 5 deletions src/gfn/containers/trajectories.py
@@ -78,9 +78,7 @@ def __init__(
         )
         assert len(self.states.batch_shape) == 2
         self.actions = (
-            actions
-            if actions is not None
-            else env.actions_from_batch_shape((0, 0))
+            actions if actions is not None else env.actions_from_batch_shape((0, 0))
         )
         assert len(self.actions.batch_shape) == 2
         self.when_is_done = (
@@ -236,9 +234,13 @@ def extend(self, other: Trajectories) -> None:
 
         # Either set, or append, estimator outputs if they exist in the submitted
         # trajectory.
-        if self.estimator_outputs is None and isinstance(other.estimator_outputs, Tensor):
+        if self.estimator_outputs is None and isinstance(
+            other.estimator_outputs, Tensor
+        ):
             self.estimator_outputs = other.estimator_outputs
-        elif isinstance(self.estimator_outputs, Tensor) and isinstance(other.estimator_outputs, Tensor):
+        elif isinstance(self.estimator_outputs, Tensor) and isinstance(
+            other.estimator_outputs, Tensor
+        ):
             batch_shape = self.actions.batch_shape
             n_bs = len(batch_shape)
             output_dtype = self.estimator_outputs.dtype
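
For context on the pattern black rewraps in the `extend` hunk above, a minimal sketch with hypothetical names (the library's real `extend` also aligns shapes before concatenating, which is omitted here):

```python
from typing import Optional

import torch
from torch import Tensor


def merge_outputs(ours: Optional[Tensor], theirs: Optional[Tensor]) -> Optional[Tensor]:
    """Set-or-append: adopt the other container's outputs when ours are unset;
    concatenate along the batch dimension when both exist."""
    if ours is None and isinstance(theirs, Tensor):
        return theirs
    elif isinstance(ours, Tensor) and isinstance(theirs, Tensor):
        # dim=1 assumes Trajectories' (max_length, n_trajectories) layout.
        return torch.cat([ours, theirs], dim=1)
    return ours  # nothing to merge
```
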
4 changes: 1 addition & 3 deletions src/gfn/containers/transitions.py
@@ -73,9 +73,7 @@ def __init__(
         assert len(self.states.batch_shape) == 1
 
         self.actions = (
-            actions
-            if actions is not None
-            else env.actions_from_batch_shape((0,))
+            actions if actions is not None else env.actions_from_batch_shape((0,))
         )
         self.is_done = (
             is_done
20 changes: 9 additions & 11 deletions src/gfn/env.py
@@ -2,8 +2,8 @@
 from typing import Optional, Tuple, Union
 
 import torch
-from torchtyping import TensorType as TT
 from torch import Tensor
+from torchtyping import TensorType as TT
 
 from gfn.actions import Actions
 from gfn.preprocessors import IdentityPreprocessor, Preprocessor
@@ -12,6 +12,7 @@
 # Errors
 NonValidActionsError = type("NonValidActionsError", (ValueError,), {})
 
+
 def get_device(device_str, default_device):
     return torch.device(device_str) if device_str is not None else default_device

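An aside, not part of the diff: the three-argument form of `type()` above builds a class dynamically, so `NonValidActionsError` is equivalent to an ordinary class statement:

```python
# Equivalent spelling of: type("NonValidActionsError", (ValueError,), {})
class NonValidActionsError(ValueError):
    pass


# Because it subclasses ValueError, existing handlers still catch it:
try:
    raise NonValidActionsError("invalid action for this state")
except ValueError as err:
    print(type(err).__name__)  # NonValidActionsError
```
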
@@ -130,6 +131,7 @@ def make_States_class(self) -> type[States]:
 
         class DefaultEnvState(States):
             """Defines a States class for this environment."""
+
             state_shape = env.state_shape
             s0 = env.s0
             sf = env.sf
@@ -215,9 +217,7 @@ def _step(
         not_done_states = new_states[~new_sink_states_idx]
         not_done_actions = actions[~new_sink_states_idx]
 
-        new_not_done_states_tensor = self.step(
-            not_done_states, not_done_actions
-        )
+        new_not_done_states_tensor = self.step(not_done_states, not_done_actions)
         # TODO: Why is this here? Should it be removed?
         # if isinstance(new_states, DiscreteStates):
         #     new_not_done_states.masks = self.update_masks(not_done_states, not_done_actions)
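
For orientation, the call black collapses above sits inside the usual masked-update idiom: gather the rows that are not done, advance only those, scatter the results back. A toy sketch (not library code):

```python
import torch

states = torch.zeros(5, 1)  # toy batch of 5 one-dimensional states
done = torch.tensor([False, True, False, True, False])

not_done_states = states[~done]  # gather the still-active rows
stepped = not_done_states + 1.0  # stand-in for env.step(...)
states[~done] = stepped          # scatter results back in place
print(states.squeeze(-1))        # tensor([1., 0., 1., 0., 1.])
```
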
@@ -247,9 +247,7 @@ def _backward_step(
         )
 
         # Calculate the backward step, and update only the states which are not Done.
-        new_not_done_states_tensor = self.backward_step(
-            valid_states, valid_actions
-        )
+        new_not_done_states_tensor = self.backward_step(valid_states, valid_actions)
         new_states.tensor[valid_states_idx] = new_not_done_states_tensor
 
         if isinstance(new_states, DiscreteStates):
@@ -316,7 +314,7 @@ def __init__(
         if isinstance(dummy_action, type(None)):
             dummy_action = torch.tensor([-1], device=device)
 
-        # The default exit action index is the final element of the action space. 
+        # The default exit action index is the final element of the action space.
         if isinstance(exit_action, type(None)):
             exit_action = torch.tensor([n_actions - 1], device=device)
 
@@ -382,7 +380,6 @@ def make_States_class(self) -> type[States]:
         env = self
 
         class DiscreteEnvStates(DiscreteStates):
-
             state_shape = env.state_shape
             s0 = env.s0
             sf = env.sf
@@ -413,7 +410,9 @@ def is_action_valid(
     def _step(self, states: DiscreteStates, actions: Actions) -> States:
         """Calls the core self._step method of the parent class, and updates masks."""
         new_states = super()._step(states, actions)
-        self.update_masks(new_states)  # TODO: update_masks is owned by the env, not the states!!
+        self.update_masks(
+            new_states
+        )  # TODO: update_masks is owned by the env, not the states!!
         return new_states
 
     def get_states_indices(
@@ -470,4 +469,3 @@ def terminating_states(self) -> DiscreteStates:
         return NotImplementedError(
             "The environment does not support enumeration of states"
         )
-
2 changes: 1 addition & 1 deletion src/gfn/gflownet/base.py
@@ -1,6 +1,6 @@
+import math
 from abc import ABC, abstractmethod
 from typing import Generic, Tuple, TypeVar, Union
-import math
 
 import torch
 import torch.nn as nn
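
This hunk is isort's default ordering at work: within a section, plain `import` statements sort ahead of `from ... import` statements, and each group is alphabetized (assuming default isort settings). For example:

```python
# How isort (defaults) orders a stdlib section:
import math
import os
from abc import ABC, abstractmethod
from typing import Tuple
```
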
12 changes: 8 additions & 4 deletions src/gfn/gym/box.py
@@ -25,8 +25,12 @@ def __init__(
         self.delta = delta
         self.epsilon = epsilon
         s0 = torch.tensor([0.0, 0.0], device=torch.device(device_str))
-        exit_action = torch.tensor([-float("inf"), -float("inf")], device=torch.device(device_str))
-        dummy_action = torch.tensor([float("inf"), float("inf")], device=torch.device(device_str))
+        exit_action = torch.tensor(
+            [-float("inf"), -float("inf")], device=torch.device(device_str)
+        )
+        dummy_action = torch.tensor(
+            [float("inf"), float("inf")], device=torch.device(device_str)
+        )
 
         self.R0 = R0
         self.R1 = R1
@@ -41,8 +45,8 @@ def __init__(
         )
 
     def make_random_states_tensor(
-            self, batch_shape: Tuple[int, ...]
-        ) -> TT["batch_shape", 2, torch.float]:
+        self, batch_shape: Tuple[int, ...]
+    ) -> TT["batch_shape", 2, torch.float]:
         return torch.rand(batch_shape + (2,), device=self.device)
 
     def step(
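
A quick check of the helper above: `batch_shape + (2,)` is tuple concatenation, so a `(4, 3)` batch of Box states samples as a `(4, 3, 2)` tensor of coordinates:

```python
import torch

batch_shape = (4, 3)
samples = torch.rand(batch_shape + (2,))  # uniform in [0, 1) per coordinate
assert samples.shape == (4, 3, 2)
```
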
4 changes: 2 additions & 2 deletions src/gfn/gym/discrete_ebm.py
@@ -2,8 +2,8 @@
 from typing import Literal, Tuple
 
 import torch
-from torch import Tensor
 import torch.nn as nn
+from torch import Tensor
 from torchtyping import TensorType as TT
 
 from gfn.actions import Actions
@@ -89,7 +89,7 @@ def __init__(
 
         super().__init__(
             s0=s0,
-            state_shape=(self.ndim, ),
+            state_shape=(self.ndim,),
             # dummy_action=,
             # exit_action=,
             n_actions=n_actions,
10 changes: 7 additions & 3 deletions src/gfn/states.py
@@ -2,7 +2,7 @@
 
 from abc import ABC, abstractmethod
 from math import prod
-from typing import ClassVar, Optional, Sequence, cast, Callable
+from typing import Callable, ClassVar, Optional, Sequence, cast
 
 import torch
 from torchtyping import TensorType as TT
@@ -49,7 +49,11 @@ class States(ABC):
     sf: ClassVar[
         TT["state_shape", torch.float]
     ]  # Dummy state, used to pad a batch of states
-    make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw(NotImplementedError("The environment does not support initialization of random states."))
+    make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw(
+        NotImplementedError(
+            "The environment does not support initialization of random states."
+        )
+    )
 
     def __init__(self, tensor: TT["batch_shape", "state_shape"]):
         """Initalize the State container with a batch of states.
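
The construct black splits over five lines here merits a note: a `lambda` body must be a single expression and `raise` is a statement, so the default raises through an empty generator's `.throw()`, which is an expression. In isolation, with a plain-function rewrite (hypothetical, not part of the diff):

```python
# The generator trick: .throw() on an empty generator raises at the call site.
fail = lambda x: (_ for _ in ()).throw(NotImplementedError("unsupported"))
# fail(None)  # would raise NotImplementedError


# The more direct spelling as a plain function:
def make_random_states_tensor(x):
    raise NotImplementedError(
        "The environment does not support initialization of random states."
    )
```
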
@@ -267,6 +271,7 @@ class DiscreteStates(States, ABC):
         forward_masks: A boolean tensor of allowable forward policy actions.
         backward_masks: A boolean tensor of allowable backward policy actions.
     """
+
     n_actions: ClassVar[int]
     device: ClassVar[torch.device]
 
@@ -276,7 +281,6 @@ def __init__(
         forward_masks: Optional[TT["batch_shape", "n_actions", torch.bool]] = None,
         backward_masks: Optional[TT["batch_shape", "n_actions - 1", torch.bool]] = None,
     ) -> None:
-
         """Initalize a DiscreteStates container with a batch of states and masks.
         Args:
             tensor: A batch of states.
2 changes: 1 addition & 1 deletion tutorials/examples/test_scripts.py
@@ -5,8 +5,8 @@
 
 from dataclasses import dataclass
 
-import pytest
 import numpy as np
+import pytest
 
 from .train_box import main as train_box_main
 from .train_discreteebm import main as train_discreteebm_main
4 changes: 1 addition & 3 deletions tutorials/examples/train_box.py
@@ -233,9 +233,7 @@ def main(args):  # noqa: C901
             print(f"current optimizer LR: {optimizer.param_groups[0]['lr']}")
 
         trajectories = gflownet.sample_trajectories(
-            env,
-            sample_off_policy=False,
-            n_samples=args.batch_size
+            env, sample_off_policy=False, n_samples=args.batch_size
         )
 
         training_samples = gflownet.to_training_samples(trajectories)
8 changes: 2 additions & 6 deletions tutorials/examples/train_discreteebm.py
@@ -20,11 +20,9 @@
 from gfn.gflownet import FMGFlowNet
 from gfn.gym import DiscreteEBM
 from gfn.modules import DiscretePolicyEstimator
-from gfn.utils.common import validate
+from gfn.utils.common import set_seed, validate
 from gfn.utils.modules import NeuralNet, Tabular
 
-from gfn.utils.common import set_seed
-
 DEFAULT_SEED = 4444
 

@@ -72,9 +70,7 @@ def main(args):  # noqa: C901
     validation_info = {"l1_dist": float("inf")}
     for iteration in trange(n_iterations):
         trajectories = gflownet.sample_trajectories(
-            env,
-            off_policy=False,
-            n_samples=args.batch_size
+            env, off_policy=False, n_samples=args.batch_size
         )
         training_samples = gflownet.to_training_samples(trajectories)
 
8 changes: 4 additions & 4 deletions tutorials/examples/train_hypergrid.py
@@ -28,11 +28,9 @@
 )
 from gfn.gym import HyperGrid
 from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
-from gfn.utils.common import validate
+from gfn.utils.common import set_seed, validate
 from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular
 
-from gfn.utils.common import set_seed
-
 DEFAULT_SEED = 4444
 

@@ -225,7 +223,9 @@ def main(args):  # noqa: C901
     n_iterations = args.n_trajectories // args.batch_size
     validation_info = {"l1_dist": float("inf")}
     for iteration in trange(n_iterations):
-        trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size, sample_off_policy=off_policy_sampling)
+        trajectories = gflownet.sample_trajectories(
+            env, n_samples=args.batch_size, sample_off_policy=off_policy_sampling
+        )
         training_samples = gflownet.to_training_samples(trajectories)
         if replay_buffer is not None:
             with torch.no_grad():
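
For orientation, the loop this hunk touches follows the usual torchgfn training recipe. A condensed sketch reusing the script's names (optimizer setup and validation omitted; the `loss` signature is assumed, not shown in this diff):

```python
for iteration in trange(n_iterations):
    trajectories = gflownet.sample_trajectories(
        env, n_samples=args.batch_size, sample_off_policy=off_policy_sampling
    )
    training_samples = gflownet.to_training_samples(trajectories)

    optimizer.zero_grad()
    loss = gflownet.loss(env, training_samples)  # assumed signature
    loss.backward()
    optimizer.step()
```
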
9 changes: 5 additions & 4 deletions tutorials/examples/train_line.py
@@ -7,10 +7,10 @@
 from tqdm import trange
 
 from gfn.gflownet import TBGFlowNet  # TODO: Extend to SubTBGFlowNet
+from gfn.gym.line import Line
 from gfn.modules import GFNModule
 from gfn.states import States
 from gfn.utils import NeuralNet
-from gfn.gym.line import Line
 from gfn.utils.common import set_seed

@@ -113,7 +113,9 @@ def log_prob(self, sampled_actions):
 
         actions_to_eval[~exit_idx] = sampled_actions[~exit_idx]
         if sum(~exit_idx) > 0:
-            logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[~exit_idx].unsqueeze(-1)
+            logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[
+                ~exit_idx
+            ].unsqueeze(-1)
 
         return logprobs.squeeze(-1)

@@ -187,6 +189,7 @@ def to_probability_distribution(
             n_steps=self.n_steps_per_trajectory,
         )
 
+
 def train(
     gflownet,
     env,
@@ -220,7 +223,6 @@ def train(
     scale_schedule = np.linspace(exploration_var_starting_val, 0, n_iterations)
 
     for iteration in tbar:
-
         optimizer.zero_grad()
         # Off Policy Sampling.
         trajectories = gflownet.sample_trajectories(
@@ -259,7 +261,6 @@
 
 
 if __name__ == "__main__":
-
     environment = Line(
         mus=[2, 5],
         sigmas=[0.5, 0.5],
