diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py
index 38f926d8..2b42eb2c 100644
--- a/src/gfn/containers/trajectories.py
+++ b/src/gfn/containers/trajectories.py
@@ -78,9 +78,7 @@ def __init__(
         )
         assert len(self.states.batch_shape) == 2
         self.actions = (
-            actions
-            if actions is not None
-            else env.actions_from_batch_shape((0, 0))
+            actions if actions is not None else env.actions_from_batch_shape((0, 0))
         )
         assert len(self.actions.batch_shape) == 2
         self.when_is_done = (
@@ -236,9 +234,14 @@ def extend(self, other: Trajectories) -> None:
 
         # Either set, or append, estimator outputs if they exist in the submitted
         # trajectory.
-        if self.estimator_outputs is None and isinstance(other.estimator_outputs, Tensor):
+        if self.estimator_outputs is None and isinstance(
+            other.estimator_outputs, Tensor
+        ):
             self.estimator_outputs = other.estimator_outputs
-        elif isinstance(self.estimator_outputs, Tensor) and isinstance(other.estimator_outputs, Tensor):
+        elif isinstance(self.estimator_outputs, Tensor) and isinstance(
+            other.estimator_outputs, Tensor
+        ):
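+            # Both are set: append the other's estimator outputs to ours.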
             batch_shape = self.actions.batch_shape
             n_bs = len(batch_shape)
             output_dtype = self.estimator_outputs.dtype
diff --git a/src/gfn/containers/transitions.py b/src/gfn/containers/transitions.py
index baddfa34..4b15f05e 100644
--- a/src/gfn/containers/transitions.py
+++ b/src/gfn/containers/transitions.py
@@ -73,9 +73,7 @@ def __init__(
         assert len(self.states.batch_shape) == 1
 
         self.actions = (
-            actions
-            if actions is not None
-            else env.actions_from_batch_shape((0,))
+            actions if actions is not None else env.actions_from_batch_shape((0,))
         )
         self.is_done = (
             is_done
diff --git a/src/gfn/env.py b/src/gfn/env.py
index d8b681c8..7d79def5 100644
--- a/src/gfn/env.py
+++ b/src/gfn/env.py
@@ -2,8 +2,8 @@
 from typing import Optional, Tuple, Union
 
 import torch
-from torchtyping import TensorType as TT
 from torch import Tensor
+from torchtyping import TensorType as TT
 
 from gfn.actions import Actions
 from gfn.preprocessors import IdentityPreprocessor, Preprocessor
@@ -12,6 +12,7 @@
 # Errors
 NonValidActionsError = type("NonValidActionsError", (ValueError,), {})
 
+
 def get_device(device_str, default_device):
     return torch.device(device_str) if device_str is not None else default_device
 
@@ -130,6 +131,7 @@ def make_States_class(self) -> type[States]:
 
         class DefaultEnvState(States):
             """Defines a States class for this environment."""
+
             state_shape = env.state_shape
             s0 = env.s0
             sf = env.sf
@@ -215,9 +217,7 @@ def _step(
         not_done_states = new_states[~new_sink_states_idx]
         not_done_actions = actions[~new_sink_states_idx]
 
-        new_not_done_states_tensor = self.step(
-            not_done_states, not_done_actions
-        )
+        new_not_done_states_tensor = self.step(not_done_states, not_done_actions)
         # TODO: Why is this here? Should it be removed?
         # if isinstance(new_states, DiscreteStates):
         #     new_not_done_states.masks = self.update_masks(not_done_states, not_done_actions)
@@ -247,9 +247,7 @@ def _backward_step(
             )
 
         # Calculate the backward step, and update only the states which are not Done.
-        new_not_done_states_tensor = self.backward_step(
-            valid_states, valid_actions
-        )
+        new_not_done_states_tensor = self.backward_step(valid_states, valid_actions)
         new_states.tensor[valid_states_idx] = new_not_done_states_tensor
 
         if isinstance(new_states, DiscreteStates):
@@ -316,7 +314,7 @@ def __init__(
-        if isinstance(dummy_action, type(None)):
+        if dummy_action is None:
             dummy_action = torch.tensor([-1], device=device)
 
-       # The default exit action index is the final element of the action space.
+        # The default exit action index is the final element of the action space.
-        if isinstance(exit_action, type(None)):
+        if exit_action is None:
             exit_action = torch.tensor([n_actions - 1], device=device)
 
@@ -382,7 +380,6 @@ def make_States_class(self) -> type[States]:
         env = self
 
         class DiscreteEnvStates(DiscreteStates):
-
             state_shape = env.state_shape
             s0 = env.s0
             sf = env.sf
@@ -413,7 +410,9 @@ def is_action_valid(
     def _step(self, states: DiscreteStates, actions: Actions) -> States:
         """Calls the core self._step method of the parent class, and updates masks."""
         new_states = super()._step(states, actions)
-        self.update_masks(new_states)  # TODO: update_masks is owned by the env, not the states!!
+        self.update_masks(
+            new_states
+        )  # TODO: update_masks is owned by the env, not the states!!
         return new_states
 
     def get_states_indices(
@@ -470,4 +469,3 @@ def terminating_states(self) -> DiscreteStates:
-        return NotImplementedError(
+        raise NotImplementedError(
             "The environment does not support enumeration of states"
         )
-
diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py
index 0656ba64..e7d80921 100644
--- a/src/gfn/gflownet/base.py
+++ b/src/gfn/gflownet/base.py
@@ -1,6 +1,6 @@
+import math
 from abc import ABC, abstractmethod
 from typing import Generic, Tuple, TypeVar, Union
-import math
 
 import torch
 import torch.nn as nn
diff --git a/src/gfn/gym/box.py b/src/gfn/gym/box.py
index 7c070682..22ed18a7 100644
--- a/src/gfn/gym/box.py
+++ b/src/gfn/gym/box.py
@@ -25,8 +25,13 @@ def __init__(
         self.delta = delta
         self.epsilon = epsilon
         s0 = torch.tensor([0.0, 0.0], device=torch.device(device_str))
-        exit_action = torch.tensor([-float("inf"), -float("inf")], device=torch.device(device_str))
-        dummy_action = torch.tensor([float("inf"), float("inf")], device=torch.device(device_str))
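+        # Sentinel actions: (-inf, -inf) is the exit action, (inf, inf) the dummy action.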
+        exit_action = torch.tensor(
+            [-float("inf"), -float("inf")], device=torch.device(device_str)
+        )
+        dummy_action = torch.tensor(
+            [float("inf"), float("inf")], device=torch.device(device_str)
+        )
 
         self.R0 = R0
         self.R1 = R1
@@ -41,8 +46,8 @@ def __init__(
         )
 
     def make_random_states_tensor(
-            self, batch_shape: Tuple[int, ...]
-        ) -> TT["batch_shape", 2, torch.float]:
+        self, batch_shape: Tuple[int, ...]
+    ) -> TT["batch_shape", 2, torch.float]:
         return torch.rand(batch_shape + (2,), device=self.device)
 
     def step(
diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py
index 85495f95..644d6cbd 100644
--- a/src/gfn/gym/discrete_ebm.py
+++ b/src/gfn/gym/discrete_ebm.py
@@ -2,8 +2,8 @@
 from typing import Literal, Tuple
 
 import torch
-from torch import Tensor
 import torch.nn as nn
+from torch import Tensor
 from torchtyping import TensorType as TT
 
 from gfn.actions import Actions
@@ -89,7 +89,7 @@ def __init__(
 
         super().__init__(
             s0=s0,
-            state_shape=(self.ndim, ),
+            state_shape=(self.ndim,),
             # dummy_action=,
             # exit_action=,
             n_actions=n_actions,
diff --git a/src/gfn/states.py b/src/gfn/states.py
index 3fd209d4..883765b8 100644
--- a/src/gfn/states.py
+++ b/src/gfn/states.py
@@ -2,7 +2,7 @@
 
 from abc import ABC, abstractmethod
 from math import prod
-from typing import ClassVar, Optional, Sequence, cast, Callable
+from typing import Callable, ClassVar, Optional, Sequence, cast
 
 import torch
 from torchtyping import TensorType as TT
@@ -49,7 +49,12 @@ class States(ABC):
     sf: ClassVar[
         TT["state_shape", torch.float]
     ]  # Dummy state, used to pad a batch of states
-    make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw(NotImplementedError("The environment does not support initialization of random states."))
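+    # Workaround to raise inside a lambda: .throw() on an empty generator raises the error.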
+    make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw(
+        NotImplementedError(
+            "The environment does not support initialization of random states."
+        )
+    )
 
     def __init__(self, tensor: TT["batch_shape", "state_shape"]):
         """Initalize the State container with a batch of states.
@@ -267,6 +272,7 @@ class DiscreteStates(States, ABC):
         forward_masks: A boolean tensor of allowable forward policy actions.
-        backward_masks:  A boolean tensor of allowable backward policy actions.
+        backward_masks: A boolean tensor of allowable backward policy actions.
     """
+
     n_actions: ClassVar[int]
     device: ClassVar[torch.device]
 
@@ -276,7 +282,6 @@ def __init__(
         forward_masks: Optional[TT["batch_shape", "n_actions", torch.bool]] = None,
         backward_masks: Optional[TT["batch_shape", "n_actions - 1", torch.bool]] = None,
     ) -> None:
-
         """Initalize a DiscreteStates container with a batch of states and masks.
         Args:
             tensor: A batch of states.
diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py
index 18801016..192a5dcb 100644
--- a/tutorials/examples/test_scripts.py
+++ b/tutorials/examples/test_scripts.py
@@ -5,8 +5,8 @@
 
 from dataclasses import dataclass
 
-import pytest
 import numpy as np
+import pytest
 
 from .train_box import main as train_box_main
 from .train_discreteebm import main as train_discreteebm_main
diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py
index 0ea3e913..e9ecbeae 100644
--- a/tutorials/examples/train_box.py
+++ b/tutorials/examples/train_box.py
@@ -233,9 +233,7 @@ def main(args):  # noqa: C901
             print(f"current optimizer LR: {optimizer.param_groups[0]['lr']}")
 
         trajectories = gflownet.sample_trajectories(
-            env,
-            sample_off_policy=False,
-            n_samples=args.batch_size
+            env, sample_off_policy=False, n_samples=args.batch_size
         )
 
         training_samples = gflownet.to_training_samples(trajectories)
diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py
index 5fdb2591..3a441648 100644
--- a/tutorials/examples/train_discreteebm.py
+++ b/tutorials/examples/train_discreteebm.py
@@ -20,11 +20,9 @@
 from gfn.gflownet import FMGFlowNet
 from gfn.gym import DiscreteEBM
 from gfn.modules import DiscretePolicyEstimator
-from gfn.utils.common import validate
+from gfn.utils.common import set_seed, validate
 from gfn.utils.modules import NeuralNet, Tabular
 
-from gfn.utils.common import set_seed
-
 DEFAULT_SEED = 4444
 
 
@@ -72,9 +70,7 @@ def main(args):  # noqa: C901
     validation_info = {"l1_dist": float("inf")}
     for iteration in trange(n_iterations):
         trajectories = gflownet.sample_trajectories(
-            env,
-            off_policy=False,
-            n_samples=args.batch_size
+            env, off_policy=False, n_samples=args.batch_size
         )
         training_samples = gflownet.to_training_samples(trajectories)
 
diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py
index f8982727..517da98e 100644
--- a/tutorials/examples/train_hypergrid.py
+++ b/tutorials/examples/train_hypergrid.py
@@ -28,11 +28,9 @@
 )
 from gfn.gym import HyperGrid
 from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
-from gfn.utils.common import validate
+from gfn.utils.common import set_seed, validate
 from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular
 
-from gfn.utils.common import set_seed
-
 DEFAULT_SEED = 4444
 
 
@@ -225,7 +223,9 @@ def main(args):  # noqa: C901
     n_iterations = args.n_trajectories // args.batch_size
     validation_info = {"l1_dist": float("inf")}
     for iteration in trange(n_iterations):
-        trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size, sample_off_policy=off_policy_sampling)
+        trajectories = gflownet.sample_trajectories(
+            env, n_samples=args.batch_size, sample_off_policy=off_policy_sampling
+        )
         training_samples = gflownet.to_training_samples(trajectories)
         if replay_buffer is not None:
             with torch.no_grad():
diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py
index 744e9294..4e69c4ee 100644
--- a/tutorials/examples/train_line.py
+++ b/tutorials/examples/train_line.py
@@ -7,10 +7,10 @@
 from tqdm import trange
 
 from gfn.gflownet import TBGFlowNet  # TODO: Extend to SubTBGFlowNet
+from gfn.gym.line import Line
 from gfn.modules import GFNModule
 from gfn.states import States
 from gfn.utils import NeuralNet
-from gfn.gym.line import Line
 from gfn.utils.common import set_seed
 
 
@@ -113,7 +113,10 @@ def log_prob(self, sampled_actions):
 
         actions_to_eval[~exit_idx] = sampled_actions[~exit_idx]
         if sum(~exit_idx) > 0:
-            logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[~exit_idx].unsqueeze(-1)
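+            # Evaluate log-probs only for the non-exit actions.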
+            logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[
+                ~exit_idx
+            ].unsqueeze(-1)
 
         return logprobs.squeeze(-1)
 
@@ -187,6 +190,7 @@ def to_probability_distribution(
             n_steps=self.n_steps_per_trajectory,
         )
 
+
 def train(
     gflownet,
     env,
@@ -220,7 +224,6 @@ def train(
     scale_schedule = np.linspace(exploration_var_starting_val, 0, n_iterations)
 
     for iteration in tbar:
-
         optimizer.zero_grad()
         # Off Policy Sampling.
         trajectories = gflownet.sample_trajectories(
@@ -259,7 +262,6 @@
 
 
 if __name__ == "__main__":
-
     environment = Line(
         mus=[2, 5],
         sigmas=[0.5, 0.5],