From 0535e9f42e950de6e3ac7245fca5f9dad1747f34 Mon Sep 17 00:00:00 2001 From: Joseph Date: Mon, 27 Nov 2023 12:56:05 -0500 Subject: [PATCH 01/28] removed one order of magnitude precision required --- tutorials/examples/test_scripts.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index e9494ec0..cf1a5546 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -69,13 +69,13 @@ def test_hypergrid(ndim: int, height: int): args = HypergridArgs(ndim=ndim, height=height, n_trajectories=n_trajectories) final_l1_dist = train_hypergrid_main(args) if ndim == 2 and height == 8: - assert np.isclose(final_l1_dist, 8.78e-4, atol=1e-5) + assert np.isclose(final_l1_dist, 8.78e-4, atol=1e-3) elif ndim == 2 and height == 16: - assert np.isclose(final_l1_dist, 4.56e-4, atol=1e-5) + assert np.isclose(final_l1_dist, 4.56e-4, atol=1e-4) elif ndim == 4 and height == 8: - assert np.isclose(final_l1_dist, 1.6e-4, atol=1e-5) + assert np.isclose(final_l1_dist, 1.6e-4, atol=1e-4) elif ndim == 4 and height == 16: - assert np.isclose(final_l1_dist, 2.45e-5, atol=1e-6) + assert np.isclose(final_l1_dist, 2.45e-5, atol=1e-5) @pytest.mark.parametrize("ndim", [2, 4]) @@ -85,13 +85,13 @@ def test_discreteebm(ndim: int, alpha: float): args = DiscreteEBMArgs(ndim=ndim, alpha=alpha, n_trajectories=n_trajectories) final_l1_dist = train_discreteebm_main(args) if ndim == 2 and alpha == 0.1: - assert np.isclose(final_l1_dist, 2.97e-3, atol=1e-3) + assert np.isclose(final_l1_dist, 2.97e-3, atol=1e-2) elif ndim == 2 and alpha == 1.0: - assert np.isclose(final_l1_dist, 0.017, atol=1e-3) + assert np.isclose(final_l1_dist, 0.017, atol=1e-2) elif ndim == 4 and alpha == 0.1: - assert np.isclose(final_l1_dist, 0.009, atol=1e-3) + assert np.isclose(final_l1_dist, 0.009, atol=1e-2) elif ndim == 4 and alpha == 1.0: - assert np.isclose(final_l1_dist, 0.062, atol=1e-3) + assert np.isclose(final_l1_dist, 0.062, atol=1e-2) @pytest.mark.parametrize("delta", [0.1, 0.25]) @@ -114,10 +114,10 @@ def test_box(delta: float, loss: str): print(args) final_jsd = train_box_main(args) if loss == "TB" and delta == 0.1: - assert np.isclose(final_jsd, 3.81e-2, atol=1e-3) + assert np.isclose(final_jsd, 3.81e-2, atol=1e-2) elif loss == "DB" and delta == 0.1: - assert np.isclose(final_jsd, 0.134, atol=1e-2) + assert np.isclose(final_jsd, 0.134, atol=1e-1) if loss == "TB" and delta == 0.25: - assert np.isclose(final_jsd, 0.0411, atol=1e-3) + assert np.isclose(final_jsd, 0.0411, atol=1e-2) elif loss == "DB" and delta == 0.25: - assert np.isclose(final_jsd, 0.0142, atol=1e-3) + assert np.isclose(final_jsd, 0.0142, atol=1e-2) From f7a562e2561ea7d7dcad81459114e55113643897 Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 12:22:37 -0500 Subject: [PATCH 02/28] replaced State method call with Env method call --- src/gfn/containers/replay_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index 179c3ad0..eebf24db 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -46,8 +46,8 @@ def __init__( self.training_objects = Transitions(env) self.objects_type = "transitions" elif objects_type == "states": - self.training_objects = env.States.from_batch_shape((0,)) - self.terminating_states = env.States.from_batch_shape((0,)) + self.training_objects = 
env.states_from_batch_shape((0,)) + self.terminating_states = env.states_from_batch_shape((0,)) self.objects_type = "states" else: raise ValueError(f"Unknown objects_type: {objects_type}") From 742bfb243d9664a4b7edfb41a707a20ff9373f50 Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 12:47:18 -0500 Subject: [PATCH 03/28] replaced State method call with Env method call, and removed is_tensor() function, for linter compatibility --- src/gfn/containers/trajectories.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index e2e25f6f..38f926d8 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -16,11 +16,6 @@ from gfn.containers.transitions import Transitions -def is_tensor(t) -> bool: - """Checks whether t is a torch.Tensor instance.""" - return isinstance(t, Tensor) - - # TODO: remove env from this class? class Trajectories(Container): """Container for complete trajectories (starting in $s_0$ and ending in $s_f$). @@ -79,13 +74,13 @@ def __init__( self.states = ( states.clone() # TODO: Do we need this clone? if states is not None - else env.States.from_batch_shape(batch_shape=(0, 0)) + else env.states_from_batch_shape((0, 0)) ) assert len(self.states.batch_shape) == 2 self.actions = ( actions if actions is not None - else env.Actions.make_dummy_actions(batch_shape=(0, 0)) + else env.actions_from_batch_shape((0, 0)) ) assert len(self.actions.batch_shape) == 2 self.when_is_done = ( @@ -168,7 +163,7 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories: self._log_rewards[index] if self._log_rewards is not None else None ) - if is_tensor(self.estimator_outputs): + if isinstance(self.estimator_outputs, Tensor): estimator_outputs = self.estimator_outputs[:, index] estimator_outputs = estimator_outputs[:new_max_length] else: @@ -241,9 +236,9 @@ def extend(self, other: Trajectories) -> None: # Either set, or append, estimator outputs if they exist in the submitted # trajectory. 
- if self.estimator_outputs is None and is_tensor(other.estimator_outputs): + if self.estimator_outputs is None and isinstance(other.estimator_outputs, Tensor): self.estimator_outputs = other.estimator_outputs - elif is_tensor(self.estimator_outputs) and is_tensor(other.estimator_outputs): + elif isinstance(self.estimator_outputs, Tensor) and isinstance(other.estimator_outputs, Tensor): batch_shape = self.actions.batch_shape n_bs = len(batch_shape) output_dtype = self.estimator_outputs.dtype From f545167245e659fd11886e6bbe897de019de9325 Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 12:47:43 -0500 Subject: [PATCH 04/28] replaced State method call with Env method call --- src/gfn/containers/transitions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gfn/containers/transitions.py b/src/gfn/containers/transitions.py index 3019f276..baddfa34 100644 --- a/src/gfn/containers/transitions.py +++ b/src/gfn/containers/transitions.py @@ -68,14 +68,14 @@ def __init__( self.states = ( states if states is not None - else env.States.from_batch_shape(batch_shape=(0,)) + else env.states_from_batch_shape(batch_shape=(0,)) ) assert len(self.states.batch_shape) == 1 self.actions = ( actions if actions is not None - else env.Actions.make_dummy_actions(batch_shape=(0,)) + else env.actions_from_batch_shape((0,)) ) self.is_done = ( is_done @@ -85,7 +85,7 @@ def __init__( self.next_states = ( next_states if next_states is not None - else env.States.from_batch_shape(batch_shape=(0,)) + else env.states_from_batch_shape(batch_shape=(0,)) ) assert ( len(self.next_states.batch_shape) == 1 From f945b3982592e5c3b7e1bcd0b29dccbc5b56b527 Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 12:48:52 -0500 Subject: [PATCH 05/28] switch name of backward/forward step, and replaced State method call with Env method call --- src/gfn/gflownet/flow_matching.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index a2acae34..061761d4 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -89,9 +89,9 @@ def flow_matching_loss( backward_actions = torch.full_like( valid_backward_states.backward_masks[:, 0], action_idx, dtype=torch.long ).unsqueeze(-1) - backward_actions = env.Actions(backward_actions) + backward_actions = env.actions_from_tensor(backward_actions) - valid_backward_states_parents = env.backward_step( + valid_backward_states_parents = env._backward_step( valid_backward_states, backward_actions ) From 2a51704e92f7bde08641c394e2b6cdb8997f44f5 Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 12:50:15 -0500 Subject: [PATCH 06/28] removed States/Actions class definition, added the appropriate args to be passed to subclasses. 
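
For illustration (not part of the diff below), a minimal continuous environment
under the new constructor-argument API looks roughly like the following sketch.
The class name, shapes, and reward are hypothetical; the pattern mirrors the Box
changes in this patch:

    import torch
    from gfn.env import Env

    class MyPointEnv(Env):
        def __init__(self, device_str="cpu"):
            device = torch.device(device_str)
            super().__init__(
                s0=torch.zeros(2, device=device),
                state_shape=(2,),
                action_shape=(2,),
                dummy_action=torch.full((2,), float("inf"), device=device),
                exit_action=torch.full((2,), -float("inf"), device=device),
            )  # The base Env builds the States/Actions classes from these args.

        def step(self, states, actions):
            # Forward transition: move by the sampled action.
            return states.tensor + actions.tensor

        def backward_step(self, states, actions):
            # Backward transition: undo the forward step.
            return states.tensor - actions.tensor

        def is_action_valid(self, states, actions, backward=False):
            return True  # No extra constraints in this toy sketch.

        def log_reward(self, final_states):
            # Toy log-reward peaked at the origin.
            return -final_states.tensor.pow(2).sum(-1)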
--- src/gfn/gym/box.py | 49 +++++++------------ src/gfn/gym/discrete_ebm.py | 94 +++++++++++++------------------------ src/gfn/gym/hypergrid.py | 63 ++++++++++--------------- 3 files changed, 74 insertions(+), 132 deletions(-) diff --git a/src/gfn/gym/box.py b/src/gfn/gym/box.py index d5a899bd..7c070682 100644 --- a/src/gfn/gym/box.py +++ b/src/gfn/gym/box.py @@ -1,5 +1,5 @@ from math import log -from typing import ClassVar, Literal, Tuple +from typing import Literal, Tuple import torch from torchtyping import TensorType as TT @@ -25,49 +25,32 @@ def __init__( self.delta = delta self.epsilon = epsilon s0 = torch.tensor([0.0, 0.0], device=torch.device(device_str)) + exit_action = torch.tensor([-float("inf"), -float("inf")], device=torch.device(device_str)) + dummy_action = torch.tensor([float("inf"), float("inf")], device=torch.device(device_str)) self.R0 = R0 self.R1 = R1 self.R2 = R2 - super().__init__(s0=s0) - - def make_States_class(self) -> type[States]: - env = self - - class BoxStates(States): - state_shape: ClassVar[Tuple[int, ...]] = (2,) - s0 = env.s0 - sf = env.sf # should be (-inf, -inf) - - @classmethod - def make_random_states_tensor( - cls, batch_shape: Tuple[int, ...] - ) -> TT["batch_shape", 2, torch.float]: - return torch.rand(batch_shape + (2,), device=env.device) - - return BoxStates - - def make_Actions_class(self) -> type[Actions]: - env = self - - class BoxActions(Actions): - action_shape: ClassVar[Tuple[int, ...]] = (2,) - dummy_action: ClassVar[TT[2]] = torch.tensor( - [float("inf"), float("inf")], device=env.device - ) - exit_action: ClassVar[TT[2]] = torch.tensor( - [-float("inf"), -float("inf")], device=env.device - ) + super().__init__( + s0=s0, + state_shape=(2,), # () + action_shape=(2,), + dummy_action=dummy_action, + exit_action=exit_action, + ) - return BoxActions + def make_random_states_tensor( + self, batch_shape: Tuple[int, ...] + ) -> TT["batch_shape", 2, torch.float]: + return torch.rand(batch_shape + (2,), device=self.device) - def maskless_step( + def step( self, states: States, actions: Actions ) -> TT["batch_shape", 2, torch.float]: return states.tensor + actions.tensor - def maskless_backward_step( + def backward_step( self, states: States, actions: Actions ) -> TT["batch_shape", 2, torch.float]: return states.tensor - actions.tensor diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py index ecd05eea..85495f95 100644 --- a/src/gfn/gym/discrete_ebm.py +++ b/src/gfn/gym/discrete_ebm.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod -from typing import ClassVar, Literal, Tuple +from typing import Literal, Tuple import torch +from torch import Tensor import torch.nn as nn from torchtyping import TensorType as TT @@ -87,69 +88,37 @@ def __init__( raise ValueError(f"Unknown preprocessor {preprocessor_name}") super().__init__( - n_actions=n_actions, s0=s0, + state_shape=(self.ndim, ), + # dummy_action=, + # exit_action=, + n_actions=n_actions, sf=sf, device_str=device_str, preprocessor=preprocessor, ) - def make_States_class(self) -> type[DiscreteStates]: - env = self - - class DiscreteEBMStates(DiscreteStates): - state_shape: ClassVar[tuple[int, ...]] = (env.ndim,) - s0 = env.s0 - sf = env.sf - n_actions = env.n_actions - device = env.device - - @classmethod - def make_random_states_tensor( - cls, batch_shape: Tuple[int, ...] 
- ) -> TT["batch_shape", "state_shape", torch.float]: - return torch.randint( - -1, - 2, - batch_shape + (env.ndim,), - dtype=torch.long, - device=env.device, - ) - - # TODO: Look into make masks - I don't think this is being called. - def make_masks( - self, - ) -> Tuple[ - TT["batch_shape", "n_actions", torch.bool], - TT["batch_shape", "n_actions - 1", torch.bool], - ]: - forward_masks = torch.zeros( - self.batch_shape + (env.n_actions,), - device=env.device, - dtype=torch.bool, - ) - backward_masks = torch.zeros( - self.batch_shape + (env.n_actions - 1,), - device=env.device, - dtype=torch.bool, - ) - - return forward_masks, backward_masks - - def update_masks(self) -> None: - self.set_default_typing() - self.forward_masks[..., : env.ndim] = self.tensor == -1 - self.forward_masks[..., env.ndim : 2 * env.ndim] = self.tensor == -1 - self.forward_masks[..., -1] = torch.all(self.tensor != -1, dim=-1) - self.backward_masks[..., : env.ndim] = self.tensor == 0 - self.backward_masks[..., env.ndim : 2 * env.ndim] = self.tensor == 1 - - return DiscreteEBMStates + def update_masks(self, states: type[States]) -> None: + states.set_default_typing() + states.forward_masks[..., : self.ndim] = states.tensor == -1 + states.forward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == -1 + states.forward_masks[..., -1] = torch.all(states.tensor != -1, dim=-1) + states.backward_masks[..., : self.ndim] = states.tensor == 0 + states.backward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == 1 + + def make_random_states_tensor(self, batch_shape: Tuple) -> Tensor: + return torch.randint( + -1, + 2, + batch_shape + (self.ndim,), + dtype=torch.long, + device=self.device, + ) def is_exit_actions(self, actions: TT["batch_shape"]) -> TT["batch_shape"]: return actions == self.n_actions - 1 - def maskless_step( + def step( self, states: States, actions: Actions ) -> TT["batch_shape", "state_shape", torch.float]: # First, we select that actions that replace a -1 with a 0. @@ -169,15 +138,18 @@ def maskless_step( ) return states.tensor - def maskless_backward_step( + def backward_step( self, states: States, actions: Actions ) -> TT["batch_shape", "state_shape", torch.float]: - # In this env, states are n-dim vectors. s0 is empty (represented as -1), - # so s0=[-1, -1, ..., -1], each action is replacing a -1 with either a - # 0 or 1. Action i in [0, ndim-1] os replacing s[i] with 0, whereas - # action i in [ndim, 2*ndim-1] corresponds to replacing s[i - ndim] with 1. - # A backward action asks "what index should be set back to -1", hence the fmod - # to enable wrapping of indices. + """Performs a backward step. + + In this env, states are n-dim vectors. s0 is empty (represented as -1), + so s0=[-1, -1, ..., -1], each action is replacing a -1 with either a + 0 or 1. Action i in [0, ndim-1] os replacing s[i] with 0, whereas + action i in [ndim, 2*ndim-1] corresponds to replacing s[i - ndim] with 1. + A backward action asks "what index should be set back to -1", hence the fmod + to enable wrapping of indices. 
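+
+        For example, with ndim=3, backward action 4 undoes the forward action
+        that set s[4 - 3] = s[1] to 1; fmod(4, 3) = 1, so s[1] is reset to -1.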
+ """ return states.tensor.scatter(-1, actions.tensor.fmod(self.ndim), -1) def reward(self, final_states: DiscreteStates) -> TT["batch_shape"]: diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index 71d2862e..b8bf27d1 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -1,7 +1,7 @@ """ Copied and Adapted from https://github.com/Tikquuss/GflowNets_Tutorial """ -from typing import ClassVar, Literal, Tuple +from typing import Literal, Tuple import torch from einops import rearrange @@ -53,7 +53,6 @@ def __init__( sf = torch.full( (ndim,), fill_value=-1, dtype=torch.long, device=torch.device(device_str) ) - n_actions = ndim + 1 if preprocessor_name == "Identity": @@ -74,55 +73,43 @@ def __init__( else: raise ValueError(f"Unknown preprocessor {preprocessor_name}") + state_shape = (self.ndim,) + super().__init__( n_actions=n_actions, s0=s0, + state_shape=state_shape, sf=sf, device_str=device_str, preprocessor=preprocessor, ) - def make_States_class(self) -> type[DiscreteStates]: - "Creates a States class for this environment" - env = self - - class HyperGridStates(DiscreteStates): - state_shape: ClassVar[tuple[int, ...]] = (env.ndim,) - s0 = env.s0 - sf = env.sf - n_actions = env.n_actions - device = env.device - - @classmethod - def make_random_states_tensor( - cls, batch_shape: Tuple[int, ...] - ) -> TT["batch_shape", "state_shape", torch.float]: - "Creates a batch of random states." - states_tensor = torch.randint( - 0, env.height, batch_shape + env.s0.shape, device=env.device - ) - return states_tensor - - def update_masks(self) -> None: - "Update the masks based on the current states." - self.set_default_typing() - # Not allowed to take any action beyond the environment height, but - # allow early termination. - self.set_nonexit_action_masks( - self.tensor == env.height - 1, - allow_exit=True, - ) - self.backward_masks = self.tensor != 0 - - return HyperGridStates - - def maskless_step( + def update_masks(self, states: type[DiscreteStates]) -> None: + """Update the masks based on the current states.""" + states.set_default_typing() + # Not allowed to take any action beyond the environment height, but + # allow early termination. + states.set_nonexit_action_masks( + states.tensor == self.height - 1, + allow_exit=True, + ) + states.backward_masks = states.tensor != 0 + + def make_random_states_tensor( + self, batch_shape: Tuple[int, ...] 
+ ) -> TT["batch_shape", "state_shape", torch.float]: + """Creates a batch of random states.""" + return torch.randint( + 0, self.height, batch_shape + self.s0.shape, device=self.device + ) + + def step( self, states: DiscreteStates, actions: Actions ) -> TT["batch_shape", "state_shape", torch.float]: new_states_tensor = states.tensor.scatter(-1, actions.tensor, 1, reduce="add") return new_states_tensor - def maskless_backward_step( + def backward_step( self, states: DiscreteStates, actions: Actions ) -> TT["batch_shape", "state_shape", torch.float]: new_states_tensor = states.tensor.scatter(-1, actions.tensor, -1, reduce="add") From cdd425cc57f32ac02642e98381c42068ab2561ac Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 12:50:29 -0500 Subject: [PATCH 07/28] moved environment to Gym --- src/gfn/gym/line.py | 90 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 src/gfn/gym/line.py diff --git a/src/gfn/gym/line.py b/src/gfn/gym/line.py new file mode 100644 index 00000000..6b246f94 --- /dev/null +++ b/src/gfn/gym/line.py @@ -0,0 +1,90 @@ +from typing import Literal + +import torch +from torch.distributions import Normal # TODO: extend to Beta +from torchtyping import TensorType as TT + +from gfn.actions import Actions +from gfn.env import Env +from gfn.states import States + + +class Line(Env): + """Mixture of Gaussians Line environment.""" + + def __init__( + self, + mus: list, + sigmas: list, + init_value: float, + n_sd: float = 4.5, + n_steps_per_trajectory: int = 5, + device_str: Literal["cpu", "cuda"] = "cpu", + ): + assert len(mus) == len(sigmas) + self.mus = torch.tensor(mus) + self.sigmas = torch.tensor(sigmas) + self.n_sd = n_sd + self.n_steps_per_trajectory = n_steps_per_trajectory + self.mixture = [Normal(m, s) for m, s in zip(self.mus, self.sigmas)] + + self.init_value = init_value # Used in s0. + self.lb = min(self.mus) - self.n_sd * max(self.sigmas) # Convienience only. + self.ub = max(self.mus) + self.n_sd * max(self.sigmas) # Convienience only. + assert self.lb < self.init_value < self.ub + + s0 = torch.tensor([self.init_value, 0.0], device=torch.device(device_str)) + dummy_action = torch.tensor([float("inf")], device=torch.device(device_str)) + exit_action = torch.tensor([-float("inf")], device=torch.device(device_str)) + super().__init__( + s0=s0, + state_shape=(2,), # [x_pos, step_counter]. + action_shape=(1,), # [x_pos] + dummy_action=dummy_action, + exit_action=exit_action, + ) # sf is -inf by defaukt. + + def step( + self, states: States, actions: Actions + ) -> TT["batch_shape", 2, torch.float]: + states.tensor[..., 0] = states.tensor[..., 0] + actions.tensor.squeeze( + -1 + ) # x position. + states.tensor[..., 1] = states.tensor[..., 1] + 1 # Step counter. + return states.tensor + + def backward_step( + self, states: States, actions: Actions + ) -> TT["batch_shape", 2, torch.float]: + states.tensor[..., 0] = states.tensor[..., 0] - actions.tensor.squeeze( + -1 + ) # x position. + states.tensor[..., 1] = states.tensor[..., 1] - 1 # Step counter. + return states.tensor + + def is_action_valid( + self, states: States, actions: Actions, backward: bool = False + ) -> bool: + # Can't take a backward step at the beginning of a trajectory. 
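+        # (Exit actions are excluded from the check; only non-exit backward
+        # actions taken from s0 make the batch invalid.)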
+ if torch.any(states[~actions.is_exit].is_initial_state) and backward: + return False + + return True + + def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]: + s = final_states.tensor[..., 0] + # return torch.logsumexp(torch.stack([m.log_prob(s) for m in self.mixture], 0), 0) + + # if s.nelement() == 0: + # return torch.zeros(final_states.batch_shape) + + log_rewards = torch.empty((len(self.mixture),) + final_states.batch_shape) + for i, m in enumerate(self.mixture): + log_rewards[i] = m.log_prob(s) + + return torch.logsumexp(log_rewards, 0) + + @property + def log_partition(self) -> float: + """Log Partition log of the number of gaussians.""" + return torch.tensor(len(self.mus)).log() From 6ae846b0a0f3458c12335754b31126688520ff75 Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 12:53:43 -0500 Subject: [PATCH 08/28] renamed maskless_?_step functions, and made the generic step/backward_step functions private. this maybe isn't the best solution as they are accessed externally by other elements of the library. mask updating is now handled by the DiscreteEnv. A generic make_States_class and make_Actions_class method is added to both Env and DiscreteEnv. --- src/gfn/env.py | 261 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 210 insertions(+), 51 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index c21f958b..6440d7fa 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -3,6 +3,7 @@ import torch from torchtyping import TensorType as TT +from torch import Tensor from gfn.actions import Actions from gfn.preprocessors import IdentityPreprocessor, Preprocessor @@ -11,6 +12,9 @@ # Errors NonValidActionsError = type("NonValidActionsError", (ValueError,), {}) +def get_device(device_str, default_device): + return torch.device(device_str) if device_str is not None else default_device + class Env(ABC): """Base class for all environments. Environments require that individual states be represented as a unique tensor of @@ -19,6 +23,10 @@ class Env(ABC): def __init__( self, s0: TT["state_shape", torch.float], + state_shape: Tuple, + action_shape: Tuple, + dummy_action: Tensor, + exit_action: Tensor, sf: Optional[TT["state_shape", torch.float]] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, @@ -28,6 +36,10 @@ def __init__( Args: s0: Representation of the initial state. All individual states would be of the same shape. + state_shape: + action_shape: + dummy_action: + exit_action: sf: Representation of the final state. Only used for a human readable representation of the states or trajectories. device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is @@ -36,12 +48,16 @@ def __init__( that can be fed into a neural network. Defaults to None, in which case the IdentityPreprocessor is used. 
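+
+        The `state_shape`, `action_shape`, `dummy_action` and `exit_action`
+        arguments are consumed by the default `make_States_class` and
+        `make_Actions_class` factories below when building this environment's
+        States and Actions classes.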
""" - self.device = torch.device(device_str) if device_str is not None else s0.device + self.device = get_device(device_str, default_device=s0.device) self.s0 = s0.to(self.device) if sf is None: sf = torch.full(s0.shape, -float("inf")).to(self.device) self.sf = sf + self.state_shape = state_shape + self.action_shape = action_shape + self.dummy_action = dummy_action + self.exit_action = exit_action self.States = self.make_States_class() self.Actions = self.make_Actions_class() @@ -56,14 +72,92 @@ def __init__( self.preprocessor = preprocessor self.is_discrete = False + def states_from_tensor(self, tensor: Tensor): + """Wraps the supplied Tensor in a States instance.""" + return self.States(tensor) + + def states_from_batch_shape(self, batch_shape: Tuple): + """Returns a batch of s0 states with a given batch_shape.""" + return self.States.from_batch_shape(batch_shape) + + def actions_from_tensor(self, tensor: Tensor): + """Wraps the supplied Tensor an an Actions instance.""" + return self.Actions(tensor) + + def actions_from_batch_shape(self, batch_shape: Tuple): + """Returns a batch of dummy actions with the supplied batch_shape.""" + return self.Actions.make_dummy_actions(batch_shape) + + # To be implemented by the User. @abstractmethod - def make_States_class(self) -> type[States]: - """Returns a class that inherits from States and implements the environment-specific methods.""" + def step( + self, states: States, actions: Actions + ) -> TT["batch_shape", "state_shape", torch.float]: + """Function that takes a batch of states and actions and returns a batch of next + states. Does not need to check whether the actions are valid or the states are sink states. + """ + + @abstractmethod + def backward_step( # TODO: rename to backward_step, other method becomes _backward_step. + self, states: States, actions: Actions + ) -> TT["batch_shape", "state_shape", torch.float]: + """Function that takes a batch of states and actions and returns a batch of previous + states. Does not need to check whether the actions are valid or the states are sink states. + """ @abstractmethod + def is_action_valid( + self, + states: States, + actions: Actions, + backward: bool = False, + ) -> bool: + """Returns True if the actions are valid in the given states.""" + + def make_random_states_tensor(self, batch_shape: Tuple) -> Tensor: + """Optional method inherited by all States instances to emit a random tensor.""" + raise NotImplementedError + + # Optionally implemented by the user when advanced functionality is required. + def make_States_class(self) -> type[States]: + """The default States class factory for all Environments. + + Returns a class that inherits from States and implements assumed methods. + The make_States_class method should be overwritten to achieve more + environment-specific States functionality. + """ + env = self + + class DefaultEnvState(States): + """Defines a States class for this environment.""" + state_shape = env.state_shape + s0 = env.s0 + sf = env.sf + make_random_states_tensor = env.make_random_states_tensor + + # @classmethod + # def make_random_states_tensor(cls, batch_shape: Tuple) -> Tensor: + # return env.make_random_states_tensor(batch_shape) + + return DefaultEnvState + def make_Actions_class(self) -> type[Actions]: - """Returns a class that inherits from Actions and implements the environment-specific methods.""" + """The default Actions class factory for all Environments. + + Returns a class that inherits from Actions and implements assumed methods. 
+ The make_Actions_class method should be overwritten to achieve more + environment-specific Actions functionality. + """ + env = self + class DefaultEnvAction(Actions): + action_shape = env.action_shape + dummy_action = env.dummy_action + exit_action = env.exit_action + + return DefaultEnvAction + + # In some cases overwritten by the user to support specific use-cases. def reset( self, batch_shape: Optional[Union[int, Tuple[int]]] = None, @@ -89,31 +183,6 @@ def reset( batch_shape=batch_shape, random=random, sink=sink ) - @abstractmethod - def maskless_step( # TODO: rename to step, other method becomes _step. - self, states: States, actions: Actions - ) -> TT["batch_shape", "state_shape", torch.float]: - """Function that takes a batch of states and actions and returns a batch of next - states. Does not need to check whether the actions are valid or the states are sink states. - """ - - @abstractmethod - def maskless_backward_step( # TODO: rename to backward_step, other method becomes _backward_step. - self, states: States, actions: Actions - ) -> TT["batch_shape", "state_shape", torch.float]: - """Function that takes a batch of states and actions and returns a batch of previous - states. Does not need to check whether the actions are valid or the states are sink states. - """ - - @abstractmethod - def is_action_valid( - self, - states: States, - actions: Actions, - backward: bool = False, - ) -> bool: - """Returns True if the actions are valid in the given states.""" - def validate_actions( self, states: States, actions: Actions, backward: bool = False ) -> bool: @@ -123,13 +192,16 @@ def validate_actions( assert states.batch_shape == actions.batch_shape return self.is_action_valid(states, actions, backward) - def step( + def _step( self, states: States, actions: Actions, ) -> States: - """Function that takes a batch of states and actions and returns a batch of next - states and a boolean tensor indicating sink states in the new batch.""" + """Core step function. Calls the user-defined self.step() function. + + Function that takes a batch of states and actions and returns a batch of next + states and a boolean tensor indicating sink states in the new batch. + """ new_states = states.clone() # TODO: Ensure this is efficient! valid_states_idx: TT["batch_shape", torch.bool] = ~states.is_sink_state valid_actions = actions[valid_states_idx] @@ -147,7 +219,7 @@ def step( not_done_states = new_states[~new_sink_states_idx] not_done_actions = actions[~new_sink_states_idx] - new_not_done_states_tensor = self.maskless_step( + new_not_done_states_tensor = self.step( not_done_states, not_done_actions ) # TODO: Why is this here? Should it be removed? @@ -158,13 +230,16 @@ def step( return new_states - def backward_step( + def _backward_step( self, states: States, actions: Actions, ) -> States: - """Function that takes a batch of states and actions and returns a batch of next - states and a boolean tensor indicating initial states in the new batch.""" + """Core backward_step function. Calls the user-defined self.backward_step fn. + + This function takes a batch of states and actions and returns a batch of next + states and a boolean tensor indicating initial states in the new batch. + """ new_states = states.clone() # TODO: Ensure this is efficient! valid_states_idx: TT["batch_shape", torch.bool] = ~new_states.is_initial_state valid_actions = actions[valid_states_idx] @@ -176,13 +251,13 @@ def backward_step( ) # Calculate the backward step, and update only the states which are not Done. 
- new_not_done_states_tensor = self.maskless_backward_step( + new_not_done_states_tensor = self.backward_step( valid_states, valid_actions ) new_states.tensor[valid_states_idx] = new_not_done_states_tensor if isinstance(new_states, DiscreteStates): - new_states.update_masks() + self.update_masks(new_states) return new_states @@ -218,6 +293,10 @@ def __init__( self, n_actions: int, s0: TT["state_shape", torch.float], + state_shape: Tuple, + # action_shape: Tuple, # TODO: Remove? I feel like we might need this. + dummy_action: Optional[TT["action_shape", torch.long]] = None, + exit_action: Optional[TT["action_shape", torch.long]] = None, sf: Optional[TT["state_shape", torch.float]] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, @@ -227,22 +306,104 @@ def __init__( Args: n_actions: The number of actions in the environment. s0: The initial state tensor (shared among all trajectories). + state_shape: + action_shape: ? + dummy_action: The value of the dummy (padding) action. + exit_action: The value of the exit action. sf: The final state tensor (shared among all trajectories). device_str: String representation of a torch.device. preprocessor: An optional preprocessor for intermediate states. """ - self.n_actions = n_actions - super().__init__(s0, sf, device_str, preprocessor) - self.is_discrete = True + device = get_device(device_str, default_device=s0.device) + + # The default dummy action is -1. + if isinstance(dummy_action, type(None)): + dummy_action = torch.tensor([-1], device=device) + + # The default exit action index is the final element of the action space. + if isinstance(exit_action, type(None)): + exit_action = torch.tensor([n_actions - 1], device=device) + + self.n_actions = n_actions # Before init, for compatibility with States. + + super().__init__( + s0, + state_shape, + (1,), # The action_shape is always 1. TODO: is it? + dummy_action, + exit_action, + sf, + device_str, + preprocessor, + ) + + self.is_discrete = True # After init, else it will be overwritten. + + def states_from_tensor(self, tensor: Tensor): + """Wraps the supplied Tensor in a States instance & updates masks.""" + states_instance = self.make_States_class()(tensor) + self.update_masks(states_instance) + return states_instance + + # In some cases overwritten by the user to support specific use-cases. + def reset( + self, + batch_shape: Optional[Union[int, Tuple[int]]] = None, + random: bool = False, + sink: bool = False, + seed: int = None, + ) -> States: + """Instantiates a batch of initial states. + + `random` and `sink` cannot be both True. When `random` is `True` and `seed` is + not `None`, environment randomization is fixed by the submitted seed for + reproducibility. + """ + assert not (random and sink) + + if random and seed is not None: + torch.manual_seed(seed) # TODO: Improve seeding here? + + if batch_shape is None: + batch_shape = (1,) + if isinstance(batch_shape, int): + batch_shape = (batch_shape,) + states = self.States.from_batch_shape( + batch_shape=batch_shape, random=random, sink=sink + ) + self.update_masks(states) + + return states + + @abstractmethod + def update_masks(self, states: type[States]) -> None: + """Updates the masks in States. + + Called automatically after each step for discrete environments. 
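+
+        A typical implementation sets `states.forward_masks` and
+        `states.backward_masks` in place from `states.tensor`; see the
+        HyperGrid and DiscreteEBM environments for concrete examples.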
+ """ + + def make_States_class(self) -> type[States]: + env = self + + class DiscreteEnvStates(DiscreteStates): + + state_shape = env.state_shape + s0 = env.s0 + sf = env.sf + make_random_states_tensor = env.make_random_states_tensor + n_actions = env.n_actions + device = env.device + + return DiscreteEnvStates def make_Actions_class(self) -> type[Actions]: env = self n_actions = self.n_actions class DiscreteEnvActions(Actions): - action_shape = (1,) - dummy_action = torch.tensor([-1], device=env.device) # Double check - exit_action = torch.tensor([n_actions - 1], device=env.device) + action_shape = env.action_shape + dummy_action = env.dummy_action.to(device=env.device) + exit_action = env.exit_action.to(device=env.device) return DiscreteEnvActions @@ -253,13 +414,10 @@ def is_action_valid( masks_tensor = states.backward_masks if backward else states.forward_masks return torch.gather(masks_tensor, 1, actions.tensor).all() - def step( - self, - states: DiscreteStates, - actions: Actions, - ) -> States: - new_states = super().step(states, actions) - new_states.update_masks() + def _step(self, states: DiscreteStates, actions: Actions) -> States: + """Calls the core self._step method of the parent class, and updates masks.""" + new_states = super()._step(states, actions) + self.update_masks(new_states) # TODO: update_masks is owned by the env, not the states!! return new_states def get_states_indices( @@ -316,3 +474,4 @@ def terminating_states(self) -> DiscreteStates: return NotImplementedError( "The environment does not support enumeration of states" ) + From a09c9a548a6c0c1b2721479777c2e08b90d8c9e5 Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 12:53:57 -0500 Subject: [PATCH 09/28] removed comment --- src/gfn/env.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index 6440d7fa..d8b681c8 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -135,10 +135,6 @@ class DefaultEnvState(States): sf = env.sf make_random_states_tensor = env.make_random_states_tensor - # @classmethod - # def make_random_states_tensor(cls, batch_shape: Tuple) -> Tensor: - # return env.make_random_states_tensor(batch_shape) - return DefaultEnvState def make_Actions_class(self) -> type[Actions]: From 93f6a65d639dba870502ee956db50872e7bc6986 Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 12:54:50 -0500 Subject: [PATCH 10/28] States methods moved to Env methods, also, name change for step --- src/gfn/samplers.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 92781664..10755915 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -78,7 +78,7 @@ def sample_actions( else: log_probs = None - actions = env.Actions(actions) + actions = env.actions_from_tensor(actions) if not save_estimator_outputs: estimator_output = None @@ -156,9 +156,7 @@ def sample_trajectories( all_estimator_outputs = [] while not all(dones): - actions = env.Actions.make_dummy_actions( - batch_shape=(n_trajectories,) - ) # TODO: Why do we need this? + actions = env.actions_from_batch_shape((n_trajectories,)) # Dummy actions. log_probs = torch.full( (n_trajectories,), fill_value=0, dtype=torch.float, device=device ) @@ -186,17 +184,16 @@ def sample_trajectories( all_estimator_outputs.append(estimator_outputs_padded) actions[~dones] = valid_actions - if ( - not skip_logprob_calculaion - ): # When off_policy, actions_log_probs are None. 
+ if not skip_logprob_calculaion: + # When off_policy, actions_log_probs are None. log_probs[~dones] = actions_log_probs trajectories_actions += [actions] trajectories_logprobs += [log_probs] if self.estimator.is_backward: - new_states = env.backward_step(states, actions) + new_states = env._backward_step(states, actions) else: - new_states = env.step(states, actions) + new_states = env._step(states, actions) sink_states_mask = new_states.is_sink_state # Increment the step, determine which trajectories are finisihed, and eval @@ -222,7 +219,7 @@ def sample_trajectories( trajectories_states += [states.tensor] trajectories_states = torch.stack(trajectories_states, dim=0) - trajectories_states = env.States(tensor=trajectories_states) + trajectories_states = env.states_from_tensor(trajectories_states) trajectories_actions = env.Actions.stack(trajectories_actions) trajectories_logprobs = torch.stack(trajectories_logprobs, dim=0) From 2ab5885dbe8104d542f06e229c35212a00f267b5 Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 12:57:37 -0500 Subject: [PATCH 11/28] changes to the handling of forward / backward masks. in addition, make_random_state_tensor is now a function passed to the States class as inheritance can no longer be relied on to overwrite the default method. --- src/gfn/states.py | 44 +++++++++++++++++--------------------------- 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index f5d63a4e..3fd209d4 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from math import prod -from typing import ClassVar, Optional, Sequence, cast +from typing import ClassVar, Optional, Sequence, cast, Callable import torch from torchtyping import TensorType as TT @@ -49,6 +49,7 @@ class States(ABC): sf: ClassVar[ TT["state_shape", torch.float] ] # Dummy state, used to pad a batch of states + make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw(NotImplementedError("The environment does not support initialization of random states.")) def __init__(self, tensor: TT["batch_shape", "state_shape"]): """Initalize the State container with a batch of states. @@ -101,15 +102,6 @@ def make_initial_states_tensor( assert cls.s0 is not None and state_ndim is not None return cls.s0.repeat(*batch_shape, *((1,) * state_ndim)) - @classmethod - def make_random_states_tensor( - cls, batch_shape: tuple[int] - ) -> TT["batch_shape", "state_shape", torch.float]: - """Makes a tensor with a `batch_shape` of random states, placeholder.""" - raise NotImplementedError( - "The environment does not support initialization of random states." - ) - @classmethod def make_sink_states_tensor( cls, batch_shape: tuple[int] @@ -133,7 +125,7 @@ def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> States: """Access particular states of the batch.""" return self.__class__( self.tensor[index] - ) # TODO: Inefficient - this make a copy of the tensor! + ) # TODO: Inefficient - this makes a copy of the tensor! def __setitem__( self, index: int | Sequence[int] | Sequence[bool], states: States @@ -275,7 +267,6 @@ class DiscreteStates(States, ABC): forward_masks: A boolean tensor of allowable forward policy actions. backward_masks: A boolean tensor of allowable backward policy actions. 
""" - n_actions: ClassVar[int] device: ClassVar[torch.device] @@ -285,32 +276,35 @@ def __init__( forward_masks: Optional[TT["batch_shape", "n_actions", torch.bool]] = None, backward_masks: Optional[TT["batch_shape", "n_actions - 1", torch.bool]] = None, ) -> None: + """Initalize a DiscreteStates container with a batch of states and masks. Args: tensor: A batch of states. - forward_masks (optional): Initializes a boolean tensor of allowable forward - policy actions. - backward_masks (optional): Initializes a boolean tensor of allowable backward - policy actions. + forward_masks: Initializes a boolean tensor of allowable forward policy + actions. + backward_masks: Initializes a boolean tensor of allowable backward policy + actions. """ super().__init__(tensor) - if forward_masks is None and backward_masks is None: - self.forward_masks = torch.ones( + # In the usual case, no masks are provided and we produce these defaults. + # Note: this **must** be updated externally by the env. + if isinstance(forward_masks, type(None)): + forward_masks = torch.ones( (*self.batch_shape, self.__class__.n_actions), dtype=torch.bool, device=self.__class__.device, ) - self.backward_masks = torch.ones( + if isinstance(backward_masks, type(None)): + backward_masks = torch.ones( (*self.batch_shape, self.__class__.n_actions - 1), dtype=torch.bool, device=self.__class__.device, ) - self.update_masks() - else: - self.forward_masks = cast(torch.Tensor, forward_masks) - self.backward_masks = cast(torch.Tensor, backward_masks) + # Ensures typecasting is consistent no matter what is submitted to init. + self.forward_masks = cast(torch.Tensor, forward_masks) # TODO: Required? + self.backward_masks = cast(torch.Tensor, backward_masks) # TODO: Required? self.set_default_typing() def clone(self) -> States: @@ -332,10 +326,6 @@ def set_default_typing(self) -> None: self.backward_masks, ) - @abstractmethod - def update_masks(self) -> None: - """Updates the masks, called after each action is taken.""" - def _check_both_forward_backward_masks_exist(self): assert self.forward_masks is not None and self.backward_masks is not None From bfd6bbfd11404c176608a54a814048f92b740846 Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 12:58:11 -0500 Subject: [PATCH 12/28] method renaming --- testing/test_environments.py | 51 ++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 29 deletions(-) diff --git a/testing/test_environments.py b/testing/test_environments.py index 9dff4d6d..b110baac 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -7,11 +7,6 @@ # Utilities. -def format_actions(a, env): - """Returns a Actions instance from a [batch_size, *action_shape] tensor of actions.""" - return env.Actions(a) - - def format_tensor(list_, discrete=True): """ If discrete, returns a long tensor with a singleton batch dimension from list @@ -89,13 +84,13 @@ def test_HyperGrid_fwd_step_with_preprocessors( failing_actions_list = [2, 0, 1] for actions_list in passing_actions_lists: - actions = format_actions(format_tensor(actions_list), env) - states = env.step(states, actions) + actions = env.actions_from_tensor(format_tensor(actions_list)) + states = env._step(states, actions) # Step 4 fails due an invalid input action. 
- actions = format_actions(format_tensor(failing_actions_list), env) + actions = env.actions_from_tensor(format_tensor(failing_actions_list)) with pytest.raises(NonValidActionsError): - states = env.step(states, actions) + states = env._step(states, actions) expected_rewards = torch.tensor([0.6, 0.1, 0.6]) assert (torch.round(env.reward(states), decimals=7) == expected_rewards).all() @@ -129,14 +124,14 @@ def test_HyperGrid_bwd_step_with_preprocessors( # All passing actions complete sucessfully. for passing_actions_list in passing_actions_lists: - actions = format_actions(format_tensor(passing_actions_list), env) - states = env.backward_step(states, actions) + actions = env.actions_from_tensor(format_tensor(passing_actions_list)) + states = env._backward_step(states, actions) # Fails due to an invalid input action. states = env.reset(batch_shape=(NDIM, ENV_HEIGHT), random=True, seed=SEED) - failing_actions = format_actions(format_tensor(failing_actions_list), env) + failing_actions = env.actions_from_tensor(format_tensor(failing_actions_list)) with pytest.raises(NonValidActionsError): - states = env.backward_step(states, failing_actions) + states = env._backward_step(states, failing_actions) def test_DiscreteEBM_fwd_step(): @@ -156,18 +151,18 @@ def test_DiscreteEBM_fwd_step(): ] # Only next possible move is [4, 4, 4, 4], for actions_list in passing_actions_lists: - actions = format_actions(format_tensor(actions_list), env) - states = env.step(states, actions) + actions = env.actions_from_tensor(format_tensor(actions_list)) + states = env._step(states, actions) # Step 4 fails due an invalid input action (15 is not possible). - actions = format_actions(format_tensor([4, 15, 4, 4]), env) + actions = env.actions_from_tensor(format_tensor([4, 15, 4, 4])) with pytest.raises(RuntimeError): - states = env.step(states, actions) + states = env._step(states, actions) # Step 5 fails due an invalid input action (1 is possible but not in this state). - actions = format_actions(format_tensor([1, 4, 4, 4]), env) + actions = env.actions_from_tensor(format_tensor([1, 4, 4, 4])) with pytest.raises(NonValidActionsError): - states = env.step(states, actions) + states = env._step(states, actions) expected_rewards = torch.tensor([1, 1, 54.5982, 1]) assert (torch.round(env.reward(states), decimals=4) == expected_rewards).all() @@ -188,15 +183,15 @@ def test_DiscreteEBM_bwd_step(): ] # All passing actions complete sucessfully. for passing_actions_list in passing_actions_lists: - actions = format_actions(format_tensor(passing_actions_list), env) - states = env.backward_step(states, actions) + actions = env.actions_from_tensor(format_tensor(passing_actions_list)) + states = env._backward_step(states, actions) # Fails due to an invalid input action. 
failing_actions_list = [0, 0, 0] states = env.reset(batch_shape=BATCH_SIZE, random=True, seed=SEED) - failing_actions = format_actions(format_tensor(failing_actions_list), env) + failing_actions = env.actions_from_tensor(format_tensor(failing_actions_list)) with pytest.raises(NonValidActionsError): - states = env.backward_step(states, failing_actions) + states = env._backward_step(states, failing_actions) @pytest.mark.parametrize("delta", [0.1, 0.5, 1.0]) @@ -214,11 +209,9 @@ def test_box_fwd_step(delta: float): ] for failing_actions_list in failing_actions_lists_at_s0: - actions = format_actions( - format_tensor(failing_actions_list, discrete=False), env - ) + actions = env.actions_from_tensor(format_tensor(failing_actions_list, discrete=False)) with pytest.raises(NonValidActionsError): - states = env.step(states, actions) + states = env._step(states, actions) # Trying the step function starting from 3 instances of s_0 A, B = None, None @@ -244,8 +237,8 @@ def test_box_fwd_step(delta: float): actions_tensor[B - A < 0] = torch.tensor([-float("inf"), -float("inf")]) actions_list = actions_tensor.tolist() - actions = format_actions(format_tensor(actions_list, discrete=False), env) - states = env.step(states, actions) + actions = env.actions_from_tensor(format_tensor(actions_list, discrete=False)) + states = env._step(states, actions) states_tensor = states.tensor # The following evaluate the maximum angles of the possible actions From d6d30fee066b266b7636bba173dc791ac33aeabe Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 12:58:49 -0500 Subject: [PATCH 13/28] docs update (TOOD: this might need a full rework) --- tutorials/ENV.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/ENV.md b/tutorials/ENV.md index 3fb30718..67f4a5ea 100644 --- a/tutorials/ENV.md +++ b/tutorials/ENV.md @@ -11,7 +11,7 @@ The user needs to implement the following two abstract functions: - The method `make_Actions_class` that creates a subclass of [`Actions`](https://github.com/saleml/torchgfn/tree/master/src/gfn/actions.py), simply by specifying the required class variables (the shape of an action tensor, the dummy action, and the exit action). This method is implemented by default for all `DiscreteEnv`s. -The logic of the environment is handled by the methods `maskless_step` and `maskless_backward_step`, that need to be implemented, which specify how an action changes a state (going forward and backward). These functions do not need to handle masking for discrete environments, nor checking whether actions are allowed, nor checking whether a state is the sink state, etc... These checks are handled in `Env.step` and `Env.backward_step` functions, that do not need to be implemented. Non discrete environments need to implement the `is_action_valid` function however, taking a batch of states and actions, and returning `True` only if all actions can be taken at the given states. +The logic of the environment is handled by the methods `maskless_step` and `maskless_backward_step`, that need to be implemented, which specify how an action changes a state (going forward and backward). These functions do not need to handle masking for discrete environments, nor checking whether actions are allowed, nor checking whether a state is the sink state, etc... These checks are handled in `Env._step` and `Env._backward_step` functions, that are not implemented by the user. 
Non-discrete environments must implement the `is_action_valid` function, taking a batch of states and actions, and returning `True` only if all actions can be taken at the given states. - The `log_reward` function that assigns the logarithm of a nonnegative reward to every terminating state (i.e. state with all $s_f$ as a child in the DAG). If `log_reward` is not implemented, `reward` needs to be. From 25b75271646cb01e60480352b5a0381cff0f427f Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 12:59:27 -0500 Subject: [PATCH 14/28] changes to support new API --- tutorials/examples/train_discreteebm.py | 2 +- tutorials/examples/train_hypergrid.py | 2 +- tutorials/examples/train_line.py | 118 ++---------------------- 3 files changed, 10 insertions(+), 112 deletions(-) diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index 68b1ba9f..5fdb2591 100644 --- a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -65,7 +65,7 @@ def main(args): # noqa: C901 # 4. Train the gflownet - visited_terminating_states = env.States.from_batch_shape((0,)) + visited_terminating_states = env.states_from_batch_shape((0,)) states_visited = 0 n_iterations = args.n_trajectories // args.batch_size diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 113df50f..f8982727 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -219,7 +219,7 @@ def main(args): # noqa: C901 optimizer = torch.optim.Adam(params) - visited_terminating_states = env.States.from_batch_shape((0,)) + visited_terminating_states = env.states_from_batch_shape((0,)) states_visited = 0 n_iterations = args.n_trajectories // args.batch_size diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py index 3d0042e5..744e9294 100644 --- a/tutorials/examples/train_line.py +++ b/tutorials/examples/train_line.py @@ -1,6 +1,3 @@ -import random -from typing import ClassVar, Literal, Tuple - import matplotlib.pyplot as plt import numpy as np import torch @@ -9,113 +6,14 @@ from torchtyping import TensorType as TT from tqdm import trange -from gfn.actions import Actions -from gfn.env import Env from gfn.gflownet import TBGFlowNet # TODO: Extend to SubTBGFlowNet from gfn.modules import GFNModule from gfn.states import States from gfn.utils import NeuralNet - +from gfn.gym.line import Line from gfn.utils.common import set_seed -class Line(Env): - """Mixture of Gaussians Line environment.""" - - def __init__( - self, - mus: list, - sigmas: list, - init_value: float, - n_sd: float = 4.5, - n_steps_per_trajectory: int = 5, - device_str: Literal["cpu", "cuda"] = "cpu", - ): - assert len(mus) == len(sigmas) - self.mus = torch.tensor(mus) - self.sigmas = torch.tensor(sigmas) - self.n_sd = n_sd - self.n_steps_per_trajectory = n_steps_per_trajectory - self.mixture = [Normal(m, s) for m, s in zip(self.mus, self.sigmas)] - - self.init_value = init_value # Used in s0. - self.lb = min(self.mus) - self.n_sd * max(self.sigmas) # Convienience only. - self.ub = max(self.mus) + self.n_sd * max(self.sigmas) # Convienience only. - assert self.lb < self.init_value < self.ub - - s0 = torch.tensor([self.init_value, 0.0], device=torch.device(device_str)) - super().__init__(s0=s0) # sf is -inf. - - def make_States_class(self) -> type[States]: - env = self - - class LineStates(States): - state_shape: ClassVar[Tuple[int, ...]] = (2,) - s0 = env.s0 # should be [init x value, 0]. 
- sf = env.sf # should be [-inf, -inf]. - - return LineStates - - def make_Actions_class(self) -> type[Actions]: - env = self - - class LineActions(Actions): - action_shape: ClassVar[Tuple[int, ...]] = (1,) # Does not include counter! - dummy_action: ClassVar[TT[2]] = torch.tensor( - [float("inf")], device=env.device - ) - exit_action: ClassVar[TT[2]] = torch.tensor( - [-float("inf")], device=env.device - ) - - return LineActions - - def maskless_step( - self, states: States, actions: Actions - ) -> TT["batch_shape", 2, torch.float]: - states.tensor[..., 0] = states.tensor[..., 0] + actions.tensor.squeeze( - -1 - ) # x position. - states.tensor[..., 1] = states.tensor[..., 1] + 1 # Step counter. - return states.tensor - - def maskless_backward_step( - self, states: States, actions: Actions - ) -> TT["batch_shape", 2, torch.float]: - states.tensor[..., 0] = states.tensor[..., 0] - actions.tensor.squeeze( - -1 - ) # x position. - states.tensor[..., 1] = states.tensor[..., 1] - 1 # Step counter. - return states.tensor - - def is_action_valid( - self, states: States, actions: Actions, backward: bool = False - ) -> bool: - # Can't take a backward step at the beginning of a trajectory. - if torch.any(states[~actions.is_exit].is_initial_state) and backward: - return False - - return True - - def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]: - s = final_states.tensor[..., 0] - # return torch.logsumexp(torch.stack([m.log_prob(s) for m in self.mixture], 0), 0) - - # if s.nelement() == 0: - # return torch.zeros(final_states.batch_shape) - - log_rewards = torch.empty((len(self.mixture),) + final_states.batch_shape) - for i, m in enumerate(self.mixture): - log_rewards[i] = m.log_prob(s) - - return torch.logsumexp(log_rewards, 0) - - @property - def log_partition(self) -> float: - """Log Partition log of the number of gaussians.""" - return torch.tensor(len(self.mus)).log() - - def render(env, validation_samples=None): """Renders the reward distribution over the 1D env.""" x = np.linspace( @@ -125,7 +23,7 @@ def render(env, validation_samples=None): ) # Get the rewards from our environment. - r = env.States( + r = env.states_from_tensor( torch.tensor(np.stack((x, torch.ones(len(x))), 1)) # Add dummy state counter. ) d = torch.exp(env.log_reward(r)) # Plots the reward, not the log reward. @@ -336,11 +234,11 @@ def train( loss.backward() # Gradient Clipping. - for p in gflownet.parameters(): - if p.ndim > 0 and p.grad is not None: # We do not clip logZ grad. - p.grad.data.clamp_( - -gradient_clip_value, gradient_clip_value - ).nan_to_num_(0.0) + # for p in gflownet.parameters(): + # if p.ndim > 0 and p.grad is not None: # We do not clip logZ grad. 
+ # p.grad.data.clamp_( + # -gradient_clip_value, gradient_clip_value + # ).nan_to_num_(0.0) optimizer.step() states_visited += len(trajectories) @@ -392,7 +290,7 @@ def train( policy_std_max=policy_std_max, ) pb = StepEstimator(environment, pb_module, backward=True) - gflownet = TBGFlowNet(pf=pf, pb=pb, off_policy=False, init_logZ=0.0) + gflownet = TBGFlowNet(pf=pf, pb=pb, off_policy=True, init_logZ=0.0) gflownet = train( gflownet, From f12cbec443af7f3c8455af096e066c7646792d0d Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 13:00:36 -0500 Subject: [PATCH 15/28] tweaks (TODO: fix in follow up PR) --- tutorials/notebooks/intro_gfn_smiley.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/notebooks/intro_gfn_smiley.ipynb b/tutorials/notebooks/intro_gfn_smiley.ipynb index 3ab820b2..8853b00b 100644 --- a/tutorials/notebooks/intro_gfn_smiley.ipynb +++ b/tutorials/notebooks/intro_gfn_smiley.ipynb @@ -1728,7 +1728,7 @@ "def train(gflownet, optimizer, env, batch_size = 128, n_episodes = 25_000):\n", " \"\"\"Training loop, keeping track of terminal states over training.\"\"\"\n", " # This stores example terminating states.\n", - " visited_terminating_states = env.States.from_batch_shape((0,))\n", + " visited_terminating_states = env.states_from_batch_shape((0,))\n", " states_visited = 0\n", " losses = []\n", "\n", From cdffab1e0b822af73a84f4a317f2dd26bb514c7b Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 13:02:01 -0500 Subject: [PATCH 16/28] black / isort --- src/gfn/containers/trajectories.py | 12 +++++++----- src/gfn/containers/transitions.py | 4 +--- src/gfn/env.py | 20 +++++++++----------- src/gfn/gflownet/base.py | 2 +- src/gfn/gym/box.py | 12 ++++++++---- src/gfn/gym/discrete_ebm.py | 4 ++-- src/gfn/states.py | 10 +++++++--- tutorials/examples/test_scripts.py | 2 +- tutorials/examples/train_box.py | 4 +--- tutorials/examples/train_discreteebm.py | 8 ++------ tutorials/examples/train_hypergrid.py | 8 ++++---- tutorials/examples/train_line.py | 9 +++++---- 12 files changed, 48 insertions(+), 47 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 38f926d8..2b42eb2c 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -78,9 +78,7 @@ def __init__( ) assert len(self.states.batch_shape) == 2 self.actions = ( - actions - if actions is not None - else env.actions_from_batch_shape((0, 0)) + actions if actions is not None else env.actions_from_batch_shape((0, 0)) ) assert len(self.actions.batch_shape) == 2 self.when_is_done = ( @@ -236,9 +234,13 @@ def extend(self, other: Trajectories) -> None: # Either set, or append, estimator outputs if they exist in the submitted # trajectory. 
- if self.estimator_outputs is None and isinstance(other.estimator_outputs, Tensor): + if self.estimator_outputs is None and isinstance( + other.estimator_outputs, Tensor + ): self.estimator_outputs = other.estimator_outputs - elif isinstance(self.estimator_outputs, Tensor) and isinstance(other.estimator_outputs, Tensor): + elif isinstance(self.estimator_outputs, Tensor) and isinstance( + other.estimator_outputs, Tensor + ): batch_shape = self.actions.batch_shape n_bs = len(batch_shape) output_dtype = self.estimator_outputs.dtype diff --git a/src/gfn/containers/transitions.py b/src/gfn/containers/transitions.py index baddfa34..4b15f05e 100644 --- a/src/gfn/containers/transitions.py +++ b/src/gfn/containers/transitions.py @@ -73,9 +73,7 @@ def __init__( assert len(self.states.batch_shape) == 1 self.actions = ( - actions - if actions is not None - else env.actions_from_batch_shape((0,)) + actions if actions is not None else env.actions_from_batch_shape((0,)) ) self.is_done = ( is_done diff --git a/src/gfn/env.py b/src/gfn/env.py index d8b681c8..7d79def5 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -2,8 +2,8 @@ from typing import Optional, Tuple, Union import torch -from torchtyping import TensorType as TT from torch import Tensor +from torchtyping import TensorType as TT from gfn.actions import Actions from gfn.preprocessors import IdentityPreprocessor, Preprocessor @@ -12,6 +12,7 @@ # Errors NonValidActionsError = type("NonValidActionsError", (ValueError,), {}) + def get_device(device_str, default_device): return torch.device(device_str) if device_str is not None else default_device @@ -130,6 +131,7 @@ def make_States_class(self) -> type[States]: class DefaultEnvState(States): """Defines a States class for this environment.""" + state_shape = env.state_shape s0 = env.s0 sf = env.sf @@ -215,9 +217,7 @@ def _step( not_done_states = new_states[~new_sink_states_idx] not_done_actions = actions[~new_sink_states_idx] - new_not_done_states_tensor = self.step( - not_done_states, not_done_actions - ) + new_not_done_states_tensor = self.step(not_done_states, not_done_actions) # TODO: Why is this here? Should it be removed? # if isinstance(new_states, DiscreteStates): # new_not_done_states.masks = self.update_masks(not_done_states, not_done_actions) @@ -247,9 +247,7 @@ def _backward_step( ) # Calculate the backward step, and update only the states which are not Done. - new_not_done_states_tensor = self.backward_step( - valid_states, valid_actions - ) + new_not_done_states_tensor = self.backward_step(valid_states, valid_actions) new_states.tensor[valid_states_idx] = new_not_done_states_tensor if isinstance(new_states, DiscreteStates): @@ -316,7 +314,7 @@ def __init__( if isinstance(dummy_action, type(None)): dummy_action = torch.tensor([-1], device=device) - # The default exit action index is the final element of the action space. + # The default exit action index is the final element of the action space. if isinstance(exit_action, type(None)): exit_action = torch.tensor([n_actions - 1], device=device) @@ -382,7 +380,6 @@ def make_States_class(self) -> type[States]: env = self class DiscreteEnvStates(DiscreteStates): - state_shape = env.state_shape s0 = env.s0 sf = env.sf @@ -413,7 +410,9 @@ def is_action_valid( def _step(self, states: DiscreteStates, actions: Actions) -> States: """Calls the core self._step method of the parent class, and updates masks.""" new_states = super()._step(states, actions) - self.update_masks(new_states) # TODO: update_masks is owned by the env, not the states!! 
+ self.update_masks( + new_states + ) # TODO: update_masks is owned by the env, not the states!! return new_states def get_states_indices( @@ -470,4 +469,3 @@ def terminating_states(self) -> DiscreteStates: return NotImplementedError( "The environment does not support enumeration of states" ) - diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 0656ba64..e7d80921 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -1,6 +1,6 @@ +import math from abc import ABC, abstractmethod from typing import Generic, Tuple, TypeVar, Union -import math import torch import torch.nn as nn diff --git a/src/gfn/gym/box.py b/src/gfn/gym/box.py index 7c070682..22ed18a7 100644 --- a/src/gfn/gym/box.py +++ b/src/gfn/gym/box.py @@ -25,8 +25,12 @@ def __init__( self.delta = delta self.epsilon = epsilon s0 = torch.tensor([0.0, 0.0], device=torch.device(device_str)) - exit_action = torch.tensor([-float("inf"), -float("inf")], device=torch.device(device_str)) - dummy_action = torch.tensor([float("inf"), float("inf")], device=torch.device(device_str)) + exit_action = torch.tensor( + [-float("inf"), -float("inf")], device=torch.device(device_str) + ) + dummy_action = torch.tensor( + [float("inf"), float("inf")], device=torch.device(device_str) + ) self.R0 = R0 self.R1 = R1 @@ -41,8 +45,8 @@ def __init__( ) def make_random_states_tensor( - self, batch_shape: Tuple[int, ...] - ) -> TT["batch_shape", 2, torch.float]: + self, batch_shape: Tuple[int, ...] + ) -> TT["batch_shape", 2, torch.float]: return torch.rand(batch_shape + (2,), device=self.device) def step( diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py index 85495f95..644d6cbd 100644 --- a/src/gfn/gym/discrete_ebm.py +++ b/src/gfn/gym/discrete_ebm.py @@ -2,8 +2,8 @@ from typing import Literal, Tuple import torch -from torch import Tensor import torch.nn as nn +from torch import Tensor from torchtyping import TensorType as TT from gfn.actions import Actions @@ -89,7 +89,7 @@ def __init__( super().__init__( s0=s0, - state_shape=(self.ndim, ), + state_shape=(self.ndim,), # dummy_action=, # exit_action=, n_actions=n_actions, diff --git a/src/gfn/states.py b/src/gfn/states.py index 3fd209d4..883765b8 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from math import prod -from typing import ClassVar, Optional, Sequence, cast, Callable +from typing import Callable, ClassVar, Optional, Sequence, cast import torch from torchtyping import TensorType as TT @@ -49,7 +49,11 @@ class States(ABC): sf: ClassVar[ TT["state_shape", torch.float] ] # Dummy state, used to pad a batch of states - make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw(NotImplementedError("The environment does not support initialization of random states.")) + make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw( + NotImplementedError( + "The environment does not support initialization of random states." + ) + ) def __init__(self, tensor: TT["batch_shape", "state_shape"]): """Initalize the State container with a batch of states. @@ -267,6 +271,7 @@ class DiscreteStates(States, ABC): forward_masks: A boolean tensor of allowable forward policy actions. backward_masks: A boolean tensor of allowable backward policy actions. 
""" + n_actions: ClassVar[int] device: ClassVar[torch.device] @@ -276,7 +281,6 @@ def __init__( forward_masks: Optional[TT["batch_shape", "n_actions", torch.bool]] = None, backward_masks: Optional[TT["batch_shape", "n_actions - 1", torch.bool]] = None, ) -> None: - """Initalize a DiscreteStates container with a batch of states and masks. Args: tensor: A batch of states. diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index 18801016..192a5dcb 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -5,8 +5,8 @@ from dataclasses import dataclass -import pytest import numpy as np +import pytest from .train_box import main as train_box_main from .train_discreteebm import main as train_discreteebm_main diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index 0ea3e913..e9ecbeae 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -233,9 +233,7 @@ def main(args): # noqa: C901 print(f"current optimizer LR: {optimizer.param_groups[0]['lr']}") trajectories = gflownet.sample_trajectories( - env, - sample_off_policy=False, - n_samples=args.batch_size + env, sample_off_policy=False, n_samples=args.batch_size ) training_samples = gflownet.to_training_samples(trajectories) diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index 5fdb2591..3a441648 100644 --- a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -20,11 +20,9 @@ from gfn.gflownet import FMGFlowNet from gfn.gym import DiscreteEBM from gfn.modules import DiscretePolicyEstimator -from gfn.utils.common import validate +from gfn.utils.common import set_seed, validate from gfn.utils.modules import NeuralNet, Tabular -from gfn.utils.common import set_seed - DEFAULT_SEED = 4444 @@ -72,9 +70,7 @@ def main(args): # noqa: C901 validation_info = {"l1_dist": float("inf")} for iteration in trange(n_iterations): trajectories = gflownet.sample_trajectories( - env, - off_policy=False, - n_samples=args.batch_size + env, off_policy=False, n_samples=args.batch_size ) training_samples = gflownet.to_training_samples(trajectories) diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index f8982727..517da98e 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -28,11 +28,9 @@ ) from gfn.gym import HyperGrid from gfn.modules import DiscretePolicyEstimator, ScalarEstimator -from gfn.utils.common import validate +from gfn.utils.common import set_seed, validate from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular -from gfn.utils.common import set_seed - DEFAULT_SEED = 4444 @@ -225,7 +223,9 @@ def main(args): # noqa: C901 n_iterations = args.n_trajectories // args.batch_size validation_info = {"l1_dist": float("inf")} for iteration in trange(n_iterations): - trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size, sample_off_policy=off_policy_sampling) + trajectories = gflownet.sample_trajectories( + env, n_samples=args.batch_size, sample_off_policy=off_policy_sampling + ) training_samples = gflownet.to_training_samples(trajectories) if replay_buffer is not None: with torch.no_grad(): diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py index 744e9294..4e69c4ee 100644 --- a/tutorials/examples/train_line.py +++ b/tutorials/examples/train_line.py @@ -7,10 +7,10 @@ from tqdm import trange from gfn.gflownet import 
TBGFlowNet # TODO: Extend to SubTBGFlowNet +from gfn.gym.line import Line from gfn.modules import GFNModule from gfn.states import States from gfn.utils import NeuralNet -from gfn.gym.line import Line from gfn.utils.common import set_seed @@ -113,7 +113,9 @@ def log_prob(self, sampled_actions): actions_to_eval[~exit_idx] = sampled_actions[~exit_idx] if sum(~exit_idx) > 0: - logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[~exit_idx].unsqueeze(-1) + logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[ + ~exit_idx + ].unsqueeze(-1) return logprobs.squeeze(-1) @@ -187,6 +189,7 @@ def to_probability_distribution( n_steps=self.n_steps_per_trajectory, ) + def train( gflownet, env, @@ -220,7 +223,6 @@ def train( scale_schedule = np.linspace(exploration_var_starting_val, 0, n_iterations) for iteration in tbar: - optimizer.zero_grad() # Off Policy Sampling. trajectories = gflownet.sample_trajectories( @@ -259,7 +261,6 @@ def train( if __name__ == "__main__": - environment = Line( mus=[2, 5], sigmas=[0.5, 0.5], From 3b756a53ceb1bc2ea15065e579ec747a539e0e7a Mon Sep 17 00:00:00 2001 From: Joseph Date: Wed, 29 Nov 2023 13:40:04 -0500 Subject: [PATCH 17/28] cleanup --- src/gfn/gym/line.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/gfn/gym/line.py b/src/gfn/gym/line.py index 6b246f94..a0b2534e 100644 --- a/src/gfn/gym/line.py +++ b/src/gfn/gym/line.py @@ -42,7 +42,7 @@ def __init__( action_shape=(1,), # [x_pos] dummy_action=dummy_action, exit_action=exit_action, - ) # sf is -inf by defaukt. + ) # sf is -inf by default. def step( self, states: States, actions: Actions @@ -73,11 +73,6 @@ def is_action_valid( def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]: s = final_states.tensor[..., 0] - # return torch.logsumexp(torch.stack([m.log_prob(s) for m in self.mixture], 0), 0) - - # if s.nelement() == 0: - # return torch.zeros(final_states.batch_shape) - log_rewards = torch.empty((len(self.mixture),) + final_states.batch_shape) for i, m in enumerate(self.mixture): log_rewards[i] = m.log_prob(s) From c98f42359245c2e44d4dd0c516e09be399134593 Mon Sep 17 00:00:00 2001 From: jdv Date: Thu, 30 Nov 2023 10:10:22 -0500 Subject: [PATCH 18/28] gradient clipping added back in --- tutorials/examples/train_line.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py index 4e69c4ee..ccbbf1cf 100644 --- a/tutorials/examples/train_line.py +++ b/tutorials/examples/train_line.py @@ -236,11 +236,11 @@ def train( loss.backward() # Gradient Clipping. - # for p in gflownet.parameters(): - # if p.ndim > 0 and p.grad is not None: # We do not clip logZ grad. - # p.grad.data.clamp_( - # -gradient_clip_value, gradient_clip_value - # ).nan_to_num_(0.0) + for p in gflownet.parameters(): + if p.ndim > 0 and p.grad is not None: # We do not clip logZ grad. 
+ p.grad.data.clamp_( + -gradient_clip_value, gradient_clip_value + ).nan_to_num_(0.0) optimizer.step() states_visited += len(trajectories) From 6af395f1c50be1129fe05aab8cc4b1d0ae6fa9d2 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 8 Dec 2023 11:38:09 -0500 Subject: [PATCH 19/28] renaming make_States_class to follow pep --- src/gfn/env.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index 7d79def5..bd5861ed 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -60,8 +60,8 @@ def __init__( self.dummy_action = dummy_action self.exit_action = exit_action - self.States = self.make_States_class() - self.Actions = self.make_Actions_class() + self.States = self.make_states_class() + self.Actions = self.make_actions_class() if preprocessor is None: assert ( @@ -120,11 +120,11 @@ def make_random_states_tensor(self, batch_shape: Tuple) -> Tensor: raise NotImplementedError # Optionally implemented by the user when advanced functionality is required. - def make_States_class(self) -> type[States]: + def make_states_class(self) -> type[States]: """The default States class factory for all Environments. Returns a class that inherits from States and implements assumed methods. - The make_States_class method should be overwritten to achieve more + The make_states_class method should be overwritten to achieve more environment-specific States functionality. """ env = self @@ -139,11 +139,11 @@ class DefaultEnvState(States): return DefaultEnvState - def make_Actions_class(self) -> type[Actions]: + def make_actions_class(self) -> type[Actions]: """The default Actions class factory for all Environments. Returns a class that inherits from Actions and implements assumed methods. - The make_Actions_class method should be overwritten to achieve more + The make_actions_class method should be overwritten to achieve more environment-specific Actions functionality. """ env = self @@ -288,7 +288,7 @@ def __init__( n_actions: int, s0: TT["state_shape", torch.float], state_shape: Tuple, - # action_shape: Tuple, # TODO: Remove? I feel like we might need this. + action_shape: Tuple = (1,), dummy_action: Optional[TT["action_shape", torch.long]] = None, exit_action: Optional[TT["action_shape", torch.long]] = None, sf: Optional[TT["state_shape", torch.float]] = None, @@ -323,7 +323,7 @@ def __init__( super().__init__( s0, state_shape, - (1,), # The action_shape is always 1. TODO: is it? + action_shape, dummy_action, exit_action, sf, @@ -335,7 +335,7 @@ def __init__( def states_from_tensor(self, tensor: Tensor): """Wraps the supplied Tensor in a States instance & updates masks.""" - states_instance = self.make_States_class()(tensor) + states_instance = self.make_states_class()(tensor) self.update_masks(states_instance) return states_instance @@ -376,7 +376,7 @@ def update_masks(self, states: type[States]) -> None: Called automatically after each step for discrete environments. 
""" - def make_States_class(self) -> type[States]: + def make_states_class(self) -> type[States]: env = self class DiscreteEnvStates(DiscreteStates): @@ -389,7 +389,7 @@ class DiscreteEnvStates(DiscreteStates): return DiscreteEnvStates - def make_Actions_class(self) -> type[Actions]: + def make_actions_class(self) -> type[Actions]: env = self n_actions = self.n_actions From 6c0d8aa753db8ce594b30a215af96aa2ed8bbfb8 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 8 Dec 2023 11:38:49 -0500 Subject: [PATCH 20/28] updated documentation --- tutorials/ENV.md | 83 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 74 insertions(+), 9 deletions(-) diff --git a/tutorials/ENV.md b/tutorials/ENV.md index 67f4a5ea..c4797698 100644 --- a/tutorials/ENV.md +++ b/tutorials/ENV.md @@ -1,21 +1,86 @@ # Creating `torchgfn` environments -To define an environment, the user needs to define the tensor `s0` representing the initial state $s_0$, from which the `state_shape` attribute is inferred, and optionally a tensor representing the sink state $s_f$, which is only used for padding incomplete trajectories. If it is not specified, `sf` is set to a tensor of the same shape as `s0` filled with $-\infty$. +To define an environment, the user needs to define the tensor `s0` representing the initial state $s_0$, and optionally a tensor representing the sink state $s_f$, which denotes the end of a trajectory (and can be used for padding). If it is not specified, `sf` is set to a tensor of the same shape as `s0` filled with $-\infty$. -If the environment is discrete, in which case it is an instance of `DiscreteEnv`, the total number of actions should be specified as an attribute. +The user must also define the `action_shape`, which may or may not be of +different dimensionality to the `state_shape`. For example, in many environments +a timestamp exists as part of the state to prevent cycles, and actions cannot +(directly) modify this value. -If the states (as represented in the `States` class) need to be transformed to another format before being processed (by neural networks for example), then the environment should define a `preprocessor` attribute, which should be an instance of the [base preprocessor class](https://github.com/saleml/torchgfn/tree/master/src/gfn/preprocessors.py). If no preprocessor is defined, the states are used as is (actually transformed using the `IdentityPreprocessor`, which transforms the state tensors to `FloatTensor`s). Implementing a specific preprocessor requires defining the `preprocess` function, and the `output_shape` attribute, which is a tuple representing the shape of *one* preprocessed state. +A `dummy_action` and `exit_action` tensor must also be submitted by the user. +The `exit_action` is a unique action that brings the state to $s_f$. The +`dummy_action` should be different from the `exit_action` (and not be a valid +trajectory) action - it's used to pad batched action tensors (after the +exit action). This is useful when trajectories will be of different lengths +within the batch. -The user needs to implement the following two abstract functions: -- The method `make_States_class` that creates the corresponding subclass of [`States`](https://github.com/saleml/torchgfn/tree/master/src/gfn/states.py). For discrete environments, the resulting class should be a subclass of [`DiscreteStates`](https://github.com/saleml/torchgfn/tree/master/src/gfn/states.py), that implements the `update_masks` method specifying which actions are available at each state. 
-- The method `make_Actions_class` that creates a subclass of [`Actions`](https://github.com/saleml/torchgfn/tree/master/src/gfn/actions.py), simply by specifying the required class variables (the shape of an action tensor, the dummy action, and the exit action). This method is implemented by default for all `DiscreteEnv`s. +In addition, a number of methods must be defined by the user: ++ `env.step(self, states, actions)` accepts a batch of states and actions, and + returns a batch of `next_states``. This is used for forward trajectories. ++ `env.backward_step(self, states, actions)` accepts a batch of `next_states` + and actions and returns a batch of `states`. This is used for backward + trajectories. + + These functions do not need to handle masking for discrete + environments, nor checking whether actions are allowed, nor checking + whether a state is the sink state, etc... These checks are handled in + `Env._step` and `Env._backward_step` functions, that are not implemented + by the user. ++ `env.is_action_valid(self, states, actions, backward)`: This function is used + to ensure all actions are valid for both forward and backward trajectores + (these are often different sets of rules) for continuous environments. It + accepts a batch of states and actions, and returning `True` only if all + actions can be taken at the given states. ++ `env.make_random_states_tensor(self, batch_shape)` is an **optional** method + which is consumed by the States class automatically, which is useful if you + want random samples you can evaluate under your reward model or policy. ++ `env.reset(self, ...)` can also **optionally** be overwritten by the user + to support custom logic. For example, for conditional GFlowNets, the + conditioning tensor can be concatenated to $s_0$ automatically here. ++ `env.log_reward(self, final_states)` must be defined, which calculates the + log reward of the terminating states (i.e. state with all $s_f$ as a child in + the DAG). It by default returns the log of `env.reward(self, final_states)`, + which is not implemented. The user can decide to either implement the `reward` + method, or if it is simpler / more numerically stable, to override the + `log_reward` method and leave the `reward` unimplemented. -The logic of the environment is handled by the methods `maskless_step` and `maskless_backward_step`, that need to be implemented, which specify how an action changes a state (going forward and backward). These functions do not need to handle masking for discrete environments, nor checking whether actions are allowed, nor checking whether a state is the sink state, etc... These checks are handled in `Env._step` and `Env._backward_step` functions, that are not implemented by the user. Non-discrete environments must implement the `is_action_valid` function, taking a batch of states and actions, and returning `True` only if all actions can be taken at the given states. -- The `log_reward` function that assigns the logarithm of a nonnegative reward to every terminating state (i.e. state with all $s_f$ as a child in the DAG). If `log_reward` is not implemented, `reward` needs to be. +If the environment is discrete, it is an instance of `DiscreteEnv`, and +therefore total number of actions should be specified as an attribute. The +`action_shape` is assumed to be `(1,)`, as the common use case of a +`DiscreteEnv` would be to sample a single action per step. 
However, this can be +set to any shape by the user (for example `(1,5)` if the policy is sampling 5 +independent actions per step). +If the states (as represented in the `States` class) need to be transformed to another format before being processed (by neural networks for example), then the environment should define a `preprocessor` attribute, which should be an instance of the [base preprocessor class](https://github.com/gfnorg/torchgfn/tree/master/src/gfn/preprocessors.py). If no preprocessor is defined, the states are used as is (actually transformed using the `IdentityPreprocessor`, which transforms the state tensors to `FloatTensor`s). Implementing a specific preprocessor requires defining the `preprocess` function, and the `output_shape` attribute, which is a tuple representing the shape of *one* preprocessed state. -For `DiscreteEnv`s, the user can define a`get_states_indices` method that assigns a unique integer number to each state, and a `n_states` property that returns an integer representing the number of states (excluding $s_f$) in the environment. The function `get_terminating_states_indices` can also be implemented and serves the purpose of uniquely identifying terminating states of the environment, which is useful for [tabular `GFNModule`s](https://github.com/saleml/torchgfn/tree/master/src/gfn/utils/modules.py). Other properties and functions can be implemented as well, such as the `log_partition` or the `true_dist_pmf` properties. +In addition to the above methods, in the discrete case, you must also define +the following method: + ++ `env.update_masks(self, states)`: in discrete environments, the `States` class + contains state-dependent forward and backward masks, which define allowable + forward and backward actions conditioned on the state. Note that in + calculating these masks, the user can leverage the helper methods + `DiscreteStates.set_nonexit_action_masks`, + `DiscreteStates.set_exit_masks`, and + `DiscreteStates.init_forward_masks`. + +The code automatically implements the following two class factories, which the +majority of users will not need to overwrite. However, the user could override +these factories to imbue new functionality into the `States` and `Actions` that +interact with the environment: +- The method `make_states_class` that creates the corresponding subclass of [`States`](https://github.com/saleml/torchgfn/tree/master/src/gfn/states.py). For discrete environments, the resulting class should be a subclass of [`DiscreteStates`](https://github.com/saleml/torchgfn/tree/master/src/gfn/states.py), that implements the `update_masks` method specifying which actions are available at each state. +- The method `make_actions_class` that creates a subclass of [`Actions`](https://github.com/saleml/torchgfn/tree/master/src/gfn/actions.py), simply by specifying the required class variables (the shape of an action tensor, the dummy action, and the exit action). This method is implemented by default for all `DiscreteEnv`s. + +The logic of the environment is handled by the methods `step` and `backward_step`, that need to be implemented, which specify how an action changes a state (going forward and backward). + +For `DiscreteEnv`s, the user can define a `get_states_indices` method that +assigns a unique integer number to each state, and a `n_states` property that +returns an integer representing the number of states (excluding $s_f$) in the environment. 
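Putting the required pieces together, the following is a minimal, hypothetical sketch of a `DiscreteEnv` subclass: a toy 1D chain whose states are the integers $0, \dots, H$, with one "move right" action plus the exit action. The `ChainEnv` name, the chain dynamics, and the reward shape are purely illustrative and not part of the library; constructor keywords beyond `n_actions`, `s0`, and `state_shape` are assumed to keep their defaults, and the masks are written directly through the `forward_masks`/`backward_masks` attributes described above.

```python
import torch

from gfn.env import DiscreteEnv
from gfn.states import DiscreteStates


class ChainEnv(DiscreteEnv):  # Hypothetical example, not part of torchgfn.
    """States are positions 0..H; action 0 moves right, action 1 (the last action) exits."""

    def __init__(self, H: int = 8):
        self.H = H
        super().__init__(n_actions=2, s0=torch.zeros(1), state_shape=(1,))

    def update_masks(self, states: DiscreteStates) -> None:
        x = states.tensor[..., 0]
        states.forward_masks[..., 0] = x < self.H  # Can still move right.
        states.forward_masks[..., 1] = True        # Exiting is always allowed here.
        states.backward_masks[..., 0] = x > 0      # Can move left unless at s_0.

    def step(self, states, actions):
        return states.tensor + 1  # The only non-exit action moves right.

    def backward_step(self, states, actions):
        return states.tensor - 1

    def log_reward(self, final_states):
        x = final_states.tensor[..., 0]
        return -0.5 * (x - self.H / 2) ** 2  # Unnormalized log-reward, peaked mid-chain.


env = ChainEnv(H=8)
visited = env.states_from_batch_shape((0,))  # Empty container, as used by the replay buffer.
```

Because every state of this toy chain is terminating, the exit action is never masked; an environment where only some states may terminate would encode that restriction in `update_masks` as well.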
The function `get_terminating_states_indices` can also be +implemented and serves the purpose of uniquely identifying terminating states of +the environment, which is useful for +[tabular `GFNModule`s](https://github.com/saleml/torchgfn/tree/master/src/gfn/utils/modules.py). +Other properties and functions can be implemented as well, such as the +`log_partition` or the `true_dist_pmf` properties. For reference, it might be useful to look at one of the following provided environments: - [HyperGrid](https://github.com/saleml/torchgfn/tree/master/src/gfn/gym/hypergrid.py) is an example of a discrete environment where all states are terminating states. - [DiscreteEBM](https://github.com/saleml/torchgfn/tree/master/src/gfn/gym/discrete_ebm.py) is an example of a discrete environment where all trajectories are of the same length but only some states are terminating. From 364e52dab3b5ed80b09796924f24a806a8d14e34 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 8 Dec 2023 11:39:09 -0500 Subject: [PATCH 21/28] rename methods --- tutorials/notebooks/intro_gfn_continuous_line.ipynb | 4 ++-- tutorials/notebooks/intro_gfn_smiley.ipynb | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tutorials/notebooks/intro_gfn_continuous_line.ipynb b/tutorials/notebooks/intro_gfn_continuous_line.ipynb index 086f5c33..232abda1 100644 --- a/tutorials/notebooks/intro_gfn_continuous_line.ipynb +++ b/tutorials/notebooks/intro_gfn_continuous_line.ipynb @@ -135,7 +135,7 @@ " sf = torch.FloatTensor([float(\"inf\"), float(\"inf\")], ).to(s0.device)\n", " super().__init__(s0=s0, sf=sf) # Overwriting the default sf of -inf.\n", "\n", - " def make_States_class(self) -> type[States]:\n", + " def make_states_class(self) -> type[States]:\n", " env = self\n", "\n", " class LineStates(States):\n", @@ -153,7 +153,7 @@ "\n", " return LineStates\n", "\n", - " def make_Actions_class(self) -> type[Actions]:\n", + " def make_actions_class(self) -> type[Actions]:\n", " env = self\n", "\n", " class LineActions(Actions):\n", diff --git a/tutorials/notebooks/intro_gfn_smiley.ipynb b/tutorials/notebooks/intro_gfn_smiley.ipynb index 8853b00b..7552ac9a 100644 --- a/tutorials/notebooks/intro_gfn_smiley.ipynb +++ b/tutorials/notebooks/intro_gfn_smiley.ipynb @@ -1624,7 +1624,7 @@ " preprocessor=IdentityPreprocessor(output_dim=state_dim)\n", " )\n", "\n", - " def make_States_class(self) -> type[DiscreteStates]:\n", + " def make_states_class(self) -> type[DiscreteStates]:\n", " \"Creates a States class for this environment\"\n", " env = self\n", "\n", @@ -1917,7 +1917,7 @@ "source": [ "# Note that here, we have the log edge flows, so we take the sum(exp(log_flows)) to\n", "# calculate the partition function estimate.\n", - "s_0 = env.make_States_class()(torch.zeros(6))\n", + "s_0 = env.make_states_class()(torch.zeros(6))\n", "print(\"Partition function estimate Z={:.2f}\".format(\n", " sum(torch.exp(estimator(s_0)[:6])) # logsumexp.\n", " )\n", From 8d8a4c14b568968954103a44bed43543c2806a0a Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 8 Dec 2023 11:40:23 -0500 Subject: [PATCH 22/28] rename method --- src/gfn/gym/helpers/test_box_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gfn/gym/helpers/test_box_utils.py b/src/gfn/gym/helpers/test_box_utils.py index fb004140..e0ac56dc 100644 --- a/src/gfn/gym/helpers/test_box_utils.py +++ b/src/gfn/gym/helpers/test_box_utils.py @@ -30,7 +30,7 @@ def test_mixed_distributions(n_components: int, n_components_s0: int): R2=2.0, device_str="cpu", ) 
- States = environment.make_States_class() + States = environment.make_states_class() # Three cases: when all states are s0, some are s0, and none are s0. centers_mixed = States(torch.FloatTensor([[0.03, 0.06], [0.0, 0.0], [0.0, 0.0]])) From ebf0db2e6a991072e8191e59885ec425d6813ccc Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 13 Feb 2024 00:24:50 -0500 Subject: [PATCH 23/28] deps --- pyproject.toml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 83730722..957a60ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,9 +33,9 @@ myst-parser = { version = "*", optional = true } pre-commit = { version = "*", optional = true } pytest = { version = "*", optional = true } renku-sphinx-theme = { version = "*", optional = true } -sphinx = { version = "*", optional = true } +sphinx = { version = ">=6.2.1", optional = true } sphinx_rtd_theme = { version = "*", optional = true } -sphinx-autoapi = { version = "*", optional = true } +sphinx-autoapi = { version = ">=3.0.0", optional = true } sphinx-math-dollar = { version = "*", optional = true } tox = { version = "*", optional = true } @@ -85,8 +85,6 @@ all = [ "Homepage" = "https://gfn.readthedocs.io/en/latest/" "Bug Tracker" = "https://github.com/saleml/gfn/issues" - - [tool.black] py36 = true include = '\.pyi?$' From 71e66033280967575b045657a8d5b8e7aaee2daf Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 13 Feb 2024 00:27:49 -0500 Subject: [PATCH 24/28] requirements --- docs/requirements_docs.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/requirements_docs.txt b/docs/requirements_docs.txt index ce2ebb44..1a812be6 100644 --- a/docs/requirements_docs.txt +++ b/docs/requirements_docs.txt @@ -1,9 +1,9 @@ pre-commit black pytest -sphinx==5.3.0 +sphinx>=6.2.1 myst-parser==0.18.1 sphinx_rtd_theme==1.1.1 sphinx-math-dollar==1.2.1 -sphinx-autoapi==2.0.0 +sphinx-autoapi>=3.0.0 renku-sphinx-theme \ No newline at end of file From c393014ed0da12350f6b08e335601547625468bc Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 13 Feb 2024 00:33:35 -0500 Subject: [PATCH 25/28] deps --- docs/requirements_docs.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/requirements_docs.txt b/docs/requirements_docs.txt index 1a812be6..45501fd3 100644 --- a/docs/requirements_docs.txt +++ b/docs/requirements_docs.txt @@ -2,8 +2,8 @@ pre-commit black pytest sphinx>=6.2.1 -myst-parser==0.18.1 -sphinx_rtd_theme==1.1.1 -sphinx-math-dollar==1.2.1 +myst-parser +sphinx_rtd_theme +sphinx-math-dollar sphinx-autoapi>=3.0.0 renku-sphinx-theme \ No newline at end of file From b85f1ebbebfb46d5c98ca6163dda89f402ef7c54 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 16 Feb 2024 13:53:30 -0500 Subject: [PATCH 26/28] merged --- src/gfn/containers/trajectories.py | 5 +++++ src/gfn/env.py | 4 +--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 251b5921..ff8e7852 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -16,6 +16,11 @@ from gfn.containers.transitions import Transitions +def is_tensor(t) -> bool: + """Checks whether t is a torch.Tensor instance.""" + return isinstance(t, Tensor) + + # TODO: remove env from this class? class Trajectories(Container): """Container for complete trajectories (starting in $s_0$ and ending in $s_f$). 
diff --git a/src/gfn/env.py b/src/gfn/env.py index 588b2553..9b045ca3 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -218,9 +218,7 @@ def _step( not_done_states = new_states[~new_sink_states_idx] not_done_actions = actions[~new_sink_states_idx] - new_not_done_states_tensor = self.maskless_step( - not_done_states, not_done_actions - ) + new_not_done_states_tensor = self.step(not_done_states, not_done_actions) new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor From ad80e7e6c34c3c7769263079f33e71a7a898b14b Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 16 Feb 2024 13:57:39 -0500 Subject: [PATCH 27/28] update --- tutorials/examples/train_line.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py index db03f684..ccbbf1cf 100644 --- a/tutorials/examples/train_line.py +++ b/tutorials/examples/train_line.py @@ -1,8 +1,3 @@ -<<<<<<< HEAD -======= -from typing import ClassVar, Literal, Tuple - ->>>>>>> eedc7e8735f6dc4aafc748f149cbae24b47d9062 import matplotlib.pyplot as plt import numpy as np import torch From ae3fa2e1c5f43187dd75b824a8339e7ac6b6b13c Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 16 Feb 2024 13:57:53 -0500 Subject: [PATCH 28/28] update --- tutorials/examples/train_discreteebm.py | 5 ----- tutorials/examples/train_hypergrid.py | 5 ----- 2 files changed, 10 deletions(-) diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index 45428423..3574fa2d 100644 --- a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -19,14 +19,9 @@ from gfn.gflownet import FMGFlowNet from gfn.gym import DiscreteEBM from gfn.modules import DiscretePolicyEstimator -<<<<<<< HEAD -from gfn.utils.common import set_seed, validate -from gfn.utils.modules import NeuralNet, Tabular -======= from gfn.utils.common import set_seed from gfn.utils.modules import NeuralNet, Tabular from gfn.utils.training import validate ->>>>>>> eedc7e8735f6dc4aafc748f149cbae24b47d9062 DEFAULT_SEED = 4444 diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index f4548af0..4d4e3a25 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -27,14 +27,9 @@ ) from gfn.gym import HyperGrid from gfn.modules import DiscretePolicyEstimator, ScalarEstimator -<<<<<<< HEAD -from gfn.utils.common import set_seed, validate -from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular -======= from gfn.utils.common import set_seed from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular from gfn.utils.training import validate ->>>>>>> eedc7e8735f6dc4aafc748f149cbae24b47d9062 DEFAULT_SEED = 4444