diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index 880314ea..c1679d9a 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -1,9 +1,10 @@ from __future__ import annotations import os -import torch from typing import TYPE_CHECKING, Literal +import torch + from gfn.containers.trajectories import Trajectories from gfn.containers.transitions import Transitions diff --git a/src/gfn/env.py b/src/gfn/env.py index 9b045ca3..d52306f6 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -389,7 +389,7 @@ class DiscreteEnvStates(DiscreteStates): def make_actions_class(self) -> type[Actions]: env = self - n_actions = self.n_actions + self.n_actions class DiscreteEnvActions(Actions): action_shape = env.action_shape diff --git a/src/gfn/states.py b/src/gfn/states.py index 38323abe..d7027873 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -1,6 +1,6 @@ from __future__ import annotations # This allows to use the class name in type hints -from abc import ABC, abstractmethod +from abc import ABC from copy import deepcopy from math import prod from typing import Callable, ClassVar, Optional, Sequence, cast diff --git a/testing/test_environments.py b/testing/test_environments.py index b110baac..5dbd4cc6 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -209,7 +209,9 @@ def test_box_fwd_step(delta: float): ] for failing_actions_list in failing_actions_lists_at_s0: - actions = env.actions_from_tensor(format_tensor(failing_actions_list, discrete=False)) + actions = env.actions_from_tensor( + format_tensor(failing_actions_list, discrete=False) + ) with pytest.raises(NonValidActionsError): states = env._step(states, actions)