diff --git a/docs/requirements_docs.txt b/docs/requirements_docs.txt index ce2ebb44..45501fd3 100644 --- a/docs/requirements_docs.txt +++ b/docs/requirements_docs.txt @@ -1,9 +1,9 @@ pre-commit black pytest -sphinx==5.3.0 -myst-parser==0.18.1 -sphinx_rtd_theme==1.1.1 -sphinx-math-dollar==1.2.1 -sphinx-autoapi==2.0.0 +sphinx>=6.2.1 +myst-parser +sphinx_rtd_theme +sphinx-math-dollar +sphinx-autoapi>=3.0.0 renku-sphinx-theme \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 83730722..957a60ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,9 +33,9 @@ myst-parser = { version = "*", optional = true } pre-commit = { version = "*", optional = true } pytest = { version = "*", optional = true } renku-sphinx-theme = { version = "*", optional = true } -sphinx = { version = "*", optional = true } +sphinx = { version = ">=6.2.1", optional = true } sphinx_rtd_theme = { version = "*", optional = true } -sphinx-autoapi = { version = "*", optional = true } +sphinx-autoapi = { version = ">=3.0.0", optional = true } sphinx-math-dollar = { version = "*", optional = true } tox = { version = "*", optional = true } @@ -85,8 +85,6 @@ all = [ "Homepage" = "https://gfn.readthedocs.io/en/latest/" "Bug Tracker" = "https://github.com/saleml/gfn/issues" - - [tool.black] py36 = true include = '\.pyi?$' diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index 179c3ad0..eebf24db 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -46,8 +46,8 @@ def __init__( self.training_objects = Transitions(env) self.objects_type = "transitions" elif objects_type == "states": - self.training_objects = env.States.from_batch_shape((0,)) - self.terminating_states = env.States.from_batch_shape((0,)) + self.training_objects = env.states_from_batch_shape((0,)) + self.terminating_states = env.states_from_batch_shape((0,)) self.objects_type = "states" else: raise ValueError(f"Unknown objects_type: {objects_type}") diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 3ca3b47e..ff8e7852 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -79,13 +79,11 @@ def __init__( self.states = ( states if states is not None - else env.States.from_batch_shape(batch_shape=(0, 0)) + else env.states_from_batch_shape((0, 0)) ) assert len(self.states.batch_shape) == 2 self.actions = ( - actions - if actions is not None - else env.Actions.make_dummy_actions(batch_shape=(0, 0)) + actions if actions is not None else env.actions_from_batch_shape((0, 0)) ) assert len(self.actions.batch_shape) == 2 self.when_is_done = ( @@ -253,9 +251,13 @@ def extend(self, other: Trajectories) -> None: # Either set, or append, estimator outputs if they exist in the submitted # trajectory. 
- if self.estimator_outputs is None and is_tensor(other.estimator_outputs): + if self.estimator_outputs is None and isinstance( + other.estimator_outputs, Tensor + ): self.estimator_outputs = other.estimator_outputs - elif is_tensor(self.estimator_outputs) and is_tensor(other.estimator_outputs): + elif isinstance(self.estimator_outputs, Tensor) and isinstance( + other.estimator_outputs, Tensor + ): batch_shape = self.actions.batch_shape n_bs = len(batch_shape) output_dtype = self.estimator_outputs.dtype diff --git a/src/gfn/containers/transitions.py b/src/gfn/containers/transitions.py index 3019f276..4b15f05e 100644 --- a/src/gfn/containers/transitions.py +++ b/src/gfn/containers/transitions.py @@ -68,14 +68,12 @@ def __init__( self.states = ( states if states is not None - else env.States.from_batch_shape(batch_shape=(0,)) + else env.states_from_batch_shape(batch_shape=(0,)) ) assert len(self.states.batch_shape) == 1 self.actions = ( - actions - if actions is not None - else env.Actions.make_dummy_actions(batch_shape=(0,)) + actions if actions is not None else env.actions_from_batch_shape((0,)) ) self.is_done = ( is_done @@ -85,7 +83,7 @@ def __init__( self.next_states = ( next_states if next_states is not None - else env.States.from_batch_shape(batch_shape=(0,)) + else env.states_from_batch_shape(batch_shape=(0,)) ) assert ( len(self.next_states.batch_shape) == 1 diff --git a/src/gfn/env.py b/src/gfn/env.py index bf2a3d3b..9b045ca3 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -2,6 +2,7 @@ from typing import Optional, Tuple, Union import torch +from torch import Tensor from torchtyping import TensorType as TT from gfn.actions import Actions @@ -13,6 +14,10 @@ NonValidActionsError = type("NonValidActionsError", (ValueError,), {}) +def get_device(device_str, default_device): + return torch.device(device_str) if device_str is not None else default_device + + class Env(ABC): """Base class for all environments. Environments require that individual states be represented as a unique tensor of arbitrary shape.""" @@ -20,6 +25,10 @@ class Env(ABC): def __init__( self, s0: TT["state_shape", torch.float], + state_shape: Tuple, + action_shape: Tuple, + dummy_action: Tensor, + exit_action: Tensor, sf: Optional[TT["state_shape", torch.float]] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, @@ -29,6 +38,10 @@ def __init__( Args: s0: Representation of the initial state. All individual states would be of the same shape. + state_shape: + action_shape: + dummy_action: + exit_action: sf: Representation of the final state. Only used for a human readable representation of the states or trajectories. device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is @@ -37,15 +50,19 @@ def __init__( that can be fed into a neural network. Defaults to None, in which case the IdentityPreprocessor is used. 
""" - self.device = torch.device(device_str) if device_str is not None else s0.device + self.device = get_device(device_str, default_device=s0.device) self.s0 = s0.to(self.device) if sf is None: sf = torch.full(s0.shape, -float("inf")).to(self.device) self.sf = sf + self.state_shape = state_shape + self.action_shape = action_shape + self.dummy_action = dummy_action + self.exit_action = exit_action - self.States = self.make_States_class() - self.Actions = self.make_Actions_class() + self.States = self.make_states_class() + self.Actions = self.make_actions_class() if preprocessor is None: assert ( @@ -57,14 +74,89 @@ def __init__( self.preprocessor = preprocessor self.is_discrete = False + def states_from_tensor(self, tensor: Tensor): + """Wraps the supplied Tensor in a States instance.""" + return self.States(tensor) + + def states_from_batch_shape(self, batch_shape: Tuple): + """Returns a batch of s0 states with a given batch_shape.""" + return self.States.from_batch_shape(batch_shape) + + def actions_from_tensor(self, tensor: Tensor): + """Wraps the supplied Tensor an an Actions instance.""" + return self.Actions(tensor) + + def actions_from_batch_shape(self, batch_shape: Tuple): + """Returns a batch of dummy actions with the supplied batch_shape.""" + return self.Actions.make_dummy_actions(batch_shape) + + # To be implemented by the User. @abstractmethod - def make_States_class(self) -> type[States]: - """Returns a class that inherits from States and implements the environment-specific methods.""" + def step( + self, states: States, actions: Actions + ) -> TT["batch_shape", "state_shape", torch.float]: + """Function that takes a batch of states and actions and returns a batch of next + states. Does not need to check whether the actions are valid or the states are sink states. + """ @abstractmethod - def make_Actions_class(self) -> type[Actions]: - """Returns a class that inherits from Actions and implements the environment-specific methods.""" + def backward_step( # TODO: rename to backward_step, other method becomes _backward_step. + self, states: States, actions: Actions + ) -> TT["batch_shape", "state_shape", torch.float]: + """Function that takes a batch of states and actions and returns a batch of previous + states. Does not need to check whether the actions are valid or the states are sink states. + """ + + @abstractmethod + def is_action_valid( + self, + states: States, + actions: Actions, + backward: bool = False, + ) -> bool: + """Returns True if the actions are valid in the given states.""" + + def make_random_states_tensor(self, batch_shape: Tuple) -> Tensor: + """Optional method inherited by all States instances to emit a random tensor.""" + raise NotImplementedError + # Optionally implemented by the user when advanced functionality is required. + def make_states_class(self) -> type[States]: + """The default States class factory for all Environments. + + Returns a class that inherits from States and implements assumed methods. + The make_states_class method should be overwritten to achieve more + environment-specific States functionality. + """ + env = self + + class DefaultEnvState(States): + """Defines a States class for this environment.""" + + state_shape = env.state_shape + s0 = env.s0 + sf = env.sf + make_random_states_tensor = env.make_random_states_tensor + + return DefaultEnvState + + def make_actions_class(self) -> type[Actions]: + """The default Actions class factory for all Environments. 
+ + Returns a class that inherits from Actions and implements assumed methods. + The make_actions_class method should be overwritten to achieve more + environment-specific Actions functionality. + """ + env = self + + class DefaultEnvAction(Actions): + action_shape = env.action_shape + dummy_action = env.dummy_action + exit_action = env.exit_action + + return DefaultEnvAction + + # In some cases overwritten by the user to support specific use-cases. def reset( self, batch_shape: Optional[Union[int, Tuple[int]]] = None, @@ -90,31 +182,6 @@ def reset( batch_shape=batch_shape, random=random, sink=sink ) - @abstractmethod - def maskless_step( # TODO: rename to step, other method becomes _step. - self, states: States, actions: Actions - ) -> TT["batch_shape", "state_shape", torch.float]: - """Function that takes a batch of states and actions and returns a batch of next - states. Does not need to check whether the actions are valid or the states are sink states. - """ - - @abstractmethod - def maskless_backward_step( # TODO: rename to backward_step, other method becomes _backward_step. - self, states: States, actions: Actions - ) -> TT["batch_shape", "state_shape", torch.float]: - """Function that takes a batch of states and actions and returns a batch of previous - states. Does not need to check whether the actions are valid or the states are sink states. - """ - - @abstractmethod - def is_action_valid( - self, - states: States, - actions: Actions, - backward: bool = False, - ) -> bool: - """Returns True if the actions are valid in the given states.""" - def validate_actions( self, states: States, actions: Actions, backward: bool = False ) -> bool: @@ -124,13 +191,16 @@ def validate_actions( assert states.batch_shape == actions.batch_shape return self.is_action_valid(states, actions, backward) - def step( + def _step( self, states: States, actions: Actions, ) -> States: - """Function that takes a batch of states and actions and returns a batch of next - states and a boolean tensor indicating sink states in the new batch.""" + """Core step function. Calls the user-defined self.step() function. + + Function that takes a batch of states and actions and returns a batch of next + states and a boolean tensor indicating sink states in the new batch. + """ new_states = states.clone() # TODO: Ensure this is efficient! valid_states_idx: TT["batch_shape", torch.bool] = ~states.is_sink_state valid_actions = actions[valid_states_idx] @@ -148,21 +218,22 @@ def step( not_done_states = new_states[~new_sink_states_idx] not_done_actions = actions[~new_sink_states_idx] - new_not_done_states_tensor = self.maskless_step( - not_done_states, not_done_actions - ) + new_not_done_states_tensor = self.step(not_done_states, not_done_actions) new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor return new_states - def backward_step( + def _backward_step( self, states: States, actions: Actions, ) -> States: - """Function that takes a batch of states and actions and returns a batch of next - states and a boolean tensor indicating initial states in the new batch.""" + """Core backward_step function. Calls the user-defined self.backward_step fn. + + This function takes a batch of states and actions and returns a batch of next + states and a boolean tensor indicating initial states in the new batch. + """ new_states = states.clone() # TODO: Ensure this is efficient! 
valid_states_idx: TT["batch_shape", torch.bool] = ~new_states.is_initial_state valid_actions = actions[valid_states_idx] @@ -174,13 +245,11 @@ def backward_step( ) # Calculate the backward step, and update only the states which are not Done. - new_not_done_states_tensor = self.maskless_backward_step( - valid_states, valid_actions - ) + new_not_done_states_tensor = self.backward_step(valid_states, valid_actions) new_states.tensor[valid_states_idx] = new_not_done_states_tensor if isinstance(new_states, DiscreteStates): - new_states.update_masks() + self.update_masks(new_states) return new_states @@ -216,6 +285,10 @@ def __init__( self, n_actions: int, s0: TT["state_shape", torch.float], + state_shape: Tuple, + action_shape: Tuple = (1,), + dummy_action: Optional[TT["action_shape", torch.long]] = None, + exit_action: Optional[TT["action_shape", torch.long]] = None, sf: Optional[TT["state_shape", torch.float]] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, @@ -225,22 +298,103 @@ def __init__( Args: n_actions: The number of actions in the environment. s0: The initial state tensor (shared among all trajectories). + state_shape: + action_shape: ? + dummy_action: The value of the dummy (padding) action. + exit_action: The value of the exit action. sf: The final state tensor (shared among all trajectories). device_str: String representation of a torch.device. preprocessor: An optional preprocessor for intermediate states. """ - self.n_actions = n_actions - super().__init__(s0, sf, device_str, preprocessor) - self.is_discrete = True + device = get_device(device_str, default_device=s0.device) + + # The default dummy action is -1. + if isinstance(dummy_action, type(None)): + dummy_action = torch.tensor([-1], device=device) + + # The default exit action index is the final element of the action space. + if isinstance(exit_action, type(None)): + exit_action = torch.tensor([n_actions - 1], device=device) + + self.n_actions = n_actions # Before init, for compatibility with States. + + super().__init__( + s0, + state_shape, + action_shape, + dummy_action, + exit_action, + sf, + device_str, + preprocessor, + ) + + self.is_discrete = True # After init, else it will be overwritten. + + def states_from_tensor(self, tensor: Tensor): + """Wraps the supplied Tensor in a States instance & updates masks.""" + states_instance = self.make_states_class()(tensor) + self.update_masks(states_instance) + return states_instance + + # In some cases overwritten by the user to support specific use-cases. + def reset( + self, + batch_shape: Optional[Union[int, Tuple[int]]] = None, + random: bool = False, + sink: bool = False, + seed: int = None, + ) -> States: + """Instantiates a batch of initial states. + + `random` and `sink` cannot be both True. When `random` is `True` and `seed` is + not `None`, environment randomization is fixed by the submitted seed for + reproducibility. + """ + assert not (random and sink) + + if random and seed is not None: + torch.manual_seed(seed) # TODO: Improve seeding here? + + if batch_shape is None: + batch_shape = (1,) + if isinstance(batch_shape, int): + batch_shape = (batch_shape,) + states = self.States.from_batch_shape( + batch_shape=batch_shape, random=random, sink=sink + ) + self.update_masks(states) - def make_Actions_class(self) -> type[Actions]: + return states + + @abstractmethod + def update_masks(self, states: type[States]) -> None: + """Updates the masks in States. + + Called automatically after each step for discrete environments. 
+ """ + + def make_states_class(self) -> type[States]: + env = self + + class DiscreteEnvStates(DiscreteStates): + state_shape = env.state_shape + s0 = env.s0 + sf = env.sf + make_random_states_tensor = env.make_random_states_tensor + n_actions = env.n_actions + device = env.device + + return DiscreteEnvStates + + def make_actions_class(self) -> type[Actions]: env = self n_actions = self.n_actions class DiscreteEnvActions(Actions): - action_shape = (1,) - dummy_action = torch.tensor([-1], device=env.device) # Double check - exit_action = torch.tensor([n_actions - 1], device=env.device) + action_shape = env.action_shape + dummy_action = env.dummy_action.to(device=env.device) + exit_action = env.exit_action.to(device=env.device) return DiscreteEnvActions @@ -251,13 +405,12 @@ def is_action_valid( masks_tensor = states.backward_masks if backward else states.forward_masks return torch.gather(masks_tensor, 1, actions.tensor).all() - def step( - self, - states: DiscreteStates, - actions: Actions, - ) -> States: - new_states = super().step(states, actions) - new_states.update_masks() + def _step(self, states: DiscreteStates, actions: Actions) -> States: + """Calls the core self._step method of the parent class, and updates masks.""" + new_states = super()._step(states, actions) + self.update_masks( + new_states + ) # TODO: update_masks is owned by the env, not the states!! return new_states def get_states_indices( diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index a2acae34..061761d4 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -89,9 +89,9 @@ def flow_matching_loss( backward_actions = torch.full_like( valid_backward_states.backward_masks[:, 0], action_idx, dtype=torch.long ).unsqueeze(-1) - backward_actions = env.Actions(backward_actions) + backward_actions = env.actions_from_tensor(backward_actions) - valid_backward_states_parents = env.backward_step( + valid_backward_states_parents = env._backward_step( valid_backward_states, backward_actions ) diff --git a/src/gfn/gym/box.py b/src/gfn/gym/box.py index d5a899bd..22ed18a7 100644 --- a/src/gfn/gym/box.py +++ b/src/gfn/gym/box.py @@ -1,5 +1,5 @@ from math import log -from typing import ClassVar, Literal, Tuple +from typing import Literal, Tuple import torch from torchtyping import TensorType as TT @@ -25,49 +25,36 @@ def __init__( self.delta = delta self.epsilon = epsilon s0 = torch.tensor([0.0, 0.0], device=torch.device(device_str)) + exit_action = torch.tensor( + [-float("inf"), -float("inf")], device=torch.device(device_str) + ) + dummy_action = torch.tensor( + [float("inf"), float("inf")], device=torch.device(device_str) + ) self.R0 = R0 self.R1 = R1 self.R2 = R2 - super().__init__(s0=s0) - - def make_States_class(self) -> type[States]: - env = self - - class BoxStates(States): - state_shape: ClassVar[Tuple[int, ...]] = (2,) - s0 = env.s0 - sf = env.sf # should be (-inf, -inf) - - @classmethod - def make_random_states_tensor( - cls, batch_shape: Tuple[int, ...] 
- ) -> TT["batch_shape", 2, torch.float]: - return torch.rand(batch_shape + (2,), device=env.device) - - return BoxStates - - def make_Actions_class(self) -> type[Actions]: - env = self - - class BoxActions(Actions): - action_shape: ClassVar[Tuple[int, ...]] = (2,) - dummy_action: ClassVar[TT[2]] = torch.tensor( - [float("inf"), float("inf")], device=env.device - ) - exit_action: ClassVar[TT[2]] = torch.tensor( - [-float("inf"), -float("inf")], device=env.device - ) + super().__init__( + s0=s0, + state_shape=(2,), # () + action_shape=(2,), + dummy_action=dummy_action, + exit_action=exit_action, + ) - return BoxActions + def make_random_states_tensor( + self, batch_shape: Tuple[int, ...] + ) -> TT["batch_shape", 2, torch.float]: + return torch.rand(batch_shape + (2,), device=self.device) - def maskless_step( + def step( self, states: States, actions: Actions ) -> TT["batch_shape", 2, torch.float]: return states.tensor + actions.tensor - def maskless_backward_step( + def backward_step( self, states: States, actions: Actions ) -> TT["batch_shape", 2, torch.float]: return states.tensor - actions.tensor diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py index a4f82735..644d6cbd 100644 --- a/src/gfn/gym/discrete_ebm.py +++ b/src/gfn/gym/discrete_ebm.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod -from typing import ClassVar, Literal, Tuple +from typing import Literal, Tuple import torch import torch.nn as nn +from torch import Tensor from torchtyping import TensorType as TT from gfn.actions import Actions @@ -87,49 +88,37 @@ def __init__( raise ValueError(f"Unknown preprocessor {preprocessor_name}") super().__init__( - n_actions=n_actions, s0=s0, + state_shape=(self.ndim,), + # dummy_action=, + # exit_action=, + n_actions=n_actions, sf=sf, device_str=device_str, preprocessor=preprocessor, ) - def make_States_class(self) -> type[DiscreteStates]: - env = self - - class DiscreteEBMStates(DiscreteStates): - state_shape: ClassVar[tuple[int, ...]] = (env.ndim,) - s0 = env.s0 - sf = env.sf - n_actions = env.n_actions - device = env.device - - @classmethod - def make_random_states_tensor( - cls, batch_shape: Tuple[int, ...] 
- ) -> TT["batch_shape", "state_shape", torch.float]: - return torch.randint( - -1, - 2, - batch_shape + (env.ndim,), - dtype=torch.long, - device=env.device, - ) - - def update_masks(self) -> None: - self.set_default_typing() - self.forward_masks[..., : env.ndim] = self.tensor == -1 - self.forward_masks[..., env.ndim : 2 * env.ndim] = self.tensor == -1 - self.forward_masks[..., -1] = torch.all(self.tensor != -1, dim=-1) - self.backward_masks[..., : env.ndim] = self.tensor == 0 - self.backward_masks[..., env.ndim : 2 * env.ndim] = self.tensor == 1 - - return DiscreteEBMStates + def update_masks(self, states: type[States]) -> None: + states.set_default_typing() + states.forward_masks[..., : self.ndim] = states.tensor == -1 + states.forward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == -1 + states.forward_masks[..., -1] = torch.all(states.tensor != -1, dim=-1) + states.backward_masks[..., : self.ndim] = states.tensor == 0 + states.backward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == 1 + + def make_random_states_tensor(self, batch_shape: Tuple) -> Tensor: + return torch.randint( + -1, + 2, + batch_shape + (self.ndim,), + dtype=torch.long, + device=self.device, + ) def is_exit_actions(self, actions: TT["batch_shape"]) -> TT["batch_shape"]: return actions == self.n_actions - 1 - def maskless_step( + def step( self, states: States, actions: Actions ) -> TT["batch_shape", "state_shape", torch.float]: # First, we select that actions that replace a -1 with a 0. @@ -149,15 +138,18 @@ def maskless_step( ) return states.tensor - def maskless_backward_step( + def backward_step( self, states: States, actions: Actions ) -> TT["batch_shape", "state_shape", torch.float]: - # In this env, states are n-dim vectors. s0 is empty (represented as -1), - # so s0=[-1, -1, ..., -1], each action is replacing a -1 with either a - # 0 or 1. Action i in [0, ndim-1] os replacing s[i] with 0, whereas - # action i in [ndim, 2*ndim-1] corresponds to replacing s[i - ndim] with 1. - # A backward action asks "what index should be set back to -1", hence the fmod - # to enable wrapping of indices. + """Performs a backward step. + + In this env, states are n-dim vectors. s0 is empty (represented as -1), + so s0=[-1, -1, ..., -1], each action is replacing a -1 with either a + 0 or 1. Action i in [0, ndim-1] os replacing s[i] with 0, whereas + action i in [ndim, 2*ndim-1] corresponds to replacing s[i - ndim] with 1. + A backward action asks "what index should be set back to -1", hence the fmod + to enable wrapping of indices. + """ return states.tensor.scatter(-1, actions.tensor.fmod(self.ndim), -1) def reward(self, final_states: DiscreteStates) -> TT["batch_shape"]: diff --git a/src/gfn/gym/helpers/test_box_utils.py b/src/gfn/gym/helpers/test_box_utils.py index fb004140..e0ac56dc 100644 --- a/src/gfn/gym/helpers/test_box_utils.py +++ b/src/gfn/gym/helpers/test_box_utils.py @@ -30,7 +30,7 @@ def test_mixed_distributions(n_components: int, n_components_s0: int): R2=2.0, device_str="cpu", ) - States = environment.make_States_class() + States = environment.make_states_class() # Three cases: when all states are s0, some are s0, and none are s0. 
centers_mixed = States(torch.FloatTensor([[0.03, 0.06], [0.0, 0.0], [0.0, 0.0]])) diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index 71d2862e..b8bf27d1 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -1,7 +1,7 @@ """ Copied and Adapted from https://github.com/Tikquuss/GflowNets_Tutorial """ -from typing import ClassVar, Literal, Tuple +from typing import Literal, Tuple import torch from einops import rearrange @@ -53,7 +53,6 @@ def __init__( sf = torch.full( (ndim,), fill_value=-1, dtype=torch.long, device=torch.device(device_str) ) - n_actions = ndim + 1 if preprocessor_name == "Identity": @@ -74,55 +73,43 @@ def __init__( else: raise ValueError(f"Unknown preprocessor {preprocessor_name}") + state_shape = (self.ndim,) + super().__init__( n_actions=n_actions, s0=s0, + state_shape=state_shape, sf=sf, device_str=device_str, preprocessor=preprocessor, ) - def make_States_class(self) -> type[DiscreteStates]: - "Creates a States class for this environment" - env = self - - class HyperGridStates(DiscreteStates): - state_shape: ClassVar[tuple[int, ...]] = (env.ndim,) - s0 = env.s0 - sf = env.sf - n_actions = env.n_actions - device = env.device - - @classmethod - def make_random_states_tensor( - cls, batch_shape: Tuple[int, ...] - ) -> TT["batch_shape", "state_shape", torch.float]: - "Creates a batch of random states." - states_tensor = torch.randint( - 0, env.height, batch_shape + env.s0.shape, device=env.device - ) - return states_tensor - - def update_masks(self) -> None: - "Update the masks based on the current states." - self.set_default_typing() - # Not allowed to take any action beyond the environment height, but - # allow early termination. - self.set_nonexit_action_masks( - self.tensor == env.height - 1, - allow_exit=True, - ) - self.backward_masks = self.tensor != 0 - - return HyperGridStates - - def maskless_step( + def update_masks(self, states: type[DiscreteStates]) -> None: + """Update the masks based on the current states.""" + states.set_default_typing() + # Not allowed to take any action beyond the environment height, but + # allow early termination. + states.set_nonexit_action_masks( + states.tensor == self.height - 1, + allow_exit=True, + ) + states.backward_masks = states.tensor != 0 + + def make_random_states_tensor( + self, batch_shape: Tuple[int, ...] 
+ ) -> TT["batch_shape", "state_shape", torch.float]: + """Creates a batch of random states.""" + return torch.randint( + 0, self.height, batch_shape + self.s0.shape, device=self.device + ) + + def step( self, states: DiscreteStates, actions: Actions ) -> TT["batch_shape", "state_shape", torch.float]: new_states_tensor = states.tensor.scatter(-1, actions.tensor, 1, reduce="add") return new_states_tensor - def maskless_backward_step( + def backward_step( self, states: DiscreteStates, actions: Actions ) -> TT["batch_shape", "state_shape", torch.float]: new_states_tensor = states.tensor.scatter(-1, actions.tensor, -1, reduce="add") diff --git a/src/gfn/gym/line.py b/src/gfn/gym/line.py new file mode 100644 index 00000000..a0b2534e --- /dev/null +++ b/src/gfn/gym/line.py @@ -0,0 +1,85 @@ +from typing import Literal + +import torch +from torch.distributions import Normal # TODO: extend to Beta +from torchtyping import TensorType as TT + +from gfn.actions import Actions +from gfn.env import Env +from gfn.states import States + + +class Line(Env): + """Mixture of Gaussians Line environment.""" + + def __init__( + self, + mus: list, + sigmas: list, + init_value: float, + n_sd: float = 4.5, + n_steps_per_trajectory: int = 5, + device_str: Literal["cpu", "cuda"] = "cpu", + ): + assert len(mus) == len(sigmas) + self.mus = torch.tensor(mus) + self.sigmas = torch.tensor(sigmas) + self.n_sd = n_sd + self.n_steps_per_trajectory = n_steps_per_trajectory + self.mixture = [Normal(m, s) for m, s in zip(self.mus, self.sigmas)] + + self.init_value = init_value # Used in s0. + self.lb = min(self.mus) - self.n_sd * max(self.sigmas) # Convienience only. + self.ub = max(self.mus) + self.n_sd * max(self.sigmas) # Convienience only. + assert self.lb < self.init_value < self.ub + + s0 = torch.tensor([self.init_value, 0.0], device=torch.device(device_str)) + dummy_action = torch.tensor([float("inf")], device=torch.device(device_str)) + exit_action = torch.tensor([-float("inf")], device=torch.device(device_str)) + super().__init__( + s0=s0, + state_shape=(2,), # [x_pos, step_counter]. + action_shape=(1,), # [x_pos] + dummy_action=dummy_action, + exit_action=exit_action, + ) # sf is -inf by default. + + def step( + self, states: States, actions: Actions + ) -> TT["batch_shape", 2, torch.float]: + states.tensor[..., 0] = states.tensor[..., 0] + actions.tensor.squeeze( + -1 + ) # x position. + states.tensor[..., 1] = states.tensor[..., 1] + 1 # Step counter. + return states.tensor + + def backward_step( + self, states: States, actions: Actions + ) -> TT["batch_shape", 2, torch.float]: + states.tensor[..., 0] = states.tensor[..., 0] - actions.tensor.squeeze( + -1 + ) # x position. + states.tensor[..., 1] = states.tensor[..., 1] - 1 # Step counter. + return states.tensor + + def is_action_valid( + self, states: States, actions: Actions, backward: bool = False + ) -> bool: + # Can't take a backward step at the beginning of a trajectory. 
+ if torch.any(states[~actions.is_exit].is_initial_state) and backward: + return False + + return True + + def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]: + s = final_states.tensor[..., 0] + log_rewards = torch.empty((len(self.mixture),) + final_states.batch_shape) + for i, m in enumerate(self.mixture): + log_rewards[i] = m.log_prob(s) + + return torch.logsumexp(log_rewards, 0) + + @property + def log_partition(self) -> float: + """Log Partition log of the number of gaussians.""" + return torch.tensor(len(self.mus)).log() diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 56cd83de..68b052a6 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -78,7 +78,7 @@ def sample_actions( else: log_probs = None - actions = env.Actions(actions) + actions = env.actions_from_tensor(actions) if not save_estimator_outputs: estimator_output = None @@ -156,9 +156,7 @@ def sample_trajectories( all_estimator_outputs = [] while not all(dones): - actions = env.Actions.make_dummy_actions( - batch_shape=(n_trajectories,) - ) # TODO: Why do we need this? + actions = env.actions_from_batch_shape((n_trajectories,)) # Dummy actions. log_probs = torch.full( (n_trajectories,), fill_value=0, dtype=torch.float, device=device ) @@ -186,17 +184,16 @@ def sample_trajectories( all_estimator_outputs.append(estimator_outputs_padded) actions[~dones] = valid_actions - if ( - not skip_logprob_calculaion - ): # When off_policy, actions_log_probs are None. + if not skip_logprob_calculaion: + # When off_policy, actions_log_probs are None. log_probs[~dones] = actions_log_probs trajectories_actions += [actions] trajectories_logprobs += [log_probs] if self.estimator.is_backward: - new_states = env.backward_step(states, actions) + new_states = env._backward_step(states, actions) else: - new_states = env.step(states, actions) + new_states = env._step(states, actions) sink_states_mask = new_states.is_sink_state # Increment the step, determine which trajectories are finisihed, and eval @@ -225,7 +222,7 @@ def sample_trajectories( trajectories_states += [states.tensor] trajectories_states = torch.stack(trajectories_states, dim=0) - trajectories_states = env.States(tensor=trajectories_states) + trajectories_states = env.states_from_tensor(trajectories_states) trajectories_actions = env.Actions.stack(trajectories_actions) trajectories_logprobs = torch.stack(trajectories_logprobs, dim=0) diff --git a/src/gfn/states.py b/src/gfn/states.py index e50b6aea..416e5670 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from copy import deepcopy from math import prod -from typing import ClassVar, Optional, Sequence, cast +from typing import Callable, ClassVar, Optional, Sequence, cast import torch from torchtyping import TensorType as TT @@ -50,6 +50,11 @@ class States(ABC): sf: ClassVar[ TT["state_shape", torch.float] ] # Dummy state, used to pad a batch of states + make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw( + NotImplementedError( + "The environment does not support initialization of random states." + ) + ) def __init__(self, tensor: TT["batch_shape", "state_shape"]): """Initalize the State container with a batch of states. 
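For illustration, a minimal sketch of how the env-bound random-state hook above reaches the generated `States` class (this uses the `HyperGrid` environment shipped with the library; the `ndim`/`height` arguments and the batch shape are arbitrary example values):

import torch
from gfn.gym import HyperGrid

env = HyperGrid(ndim=2, height=4)
StatesClass = env.make_states_class()

# `random=True` delegates to env.make_random_states_tensor; environments that do
# not implement that method hit the NotImplementedError default instead.
random_states = StatesClass.from_batch_shape((5,), random=True)
assert random_states.tensor.shape == (5, 2)  # (*batch_shape, *state_shape).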
@@ -102,15 +107,6 @@ def make_initial_states_tensor( assert cls.s0 is not None and state_ndim is not None return cls.s0.repeat(*batch_shape, *((1,) * state_ndim)) - @classmethod - def make_random_states_tensor( - cls, batch_shape: tuple[int] - ) -> TT["batch_shape", "state_shape", torch.float]: - """Makes a tensor with a `batch_shape` of random states, placeholder.""" - raise NotImplementedError( - "The environment does not support initialization of random states." - ) - @classmethod def make_sink_states_tensor( cls, batch_shape: tuple[int] @@ -288,29 +284,31 @@ def __init__( """Initalize a DiscreteStates container with a batch of states and masks. Args: tensor: A batch of states. - forward_masks (optional): Initializes a boolean tensor of allowable forward - policy actions. - backward_masks (optional): Initializes a boolean tensor of allowable backward - policy actions. + forward_masks: Initializes a boolean tensor of allowable forward policy + actions. + backward_masks: Initializes a boolean tensor of allowable backward policy + actions. """ super().__init__(tensor) - if forward_masks is None and backward_masks is None: - self.forward_masks = torch.ones( + # In the usual case, no masks are provided and we produce these defaults. + # Note: this **must** be updated externally by the env. + if isinstance(forward_masks, type(None)): + forward_masks = torch.ones( (*self.batch_shape, self.__class__.n_actions), dtype=torch.bool, device=self.__class__.device, ) - self.backward_masks = torch.ones( + if isinstance(backward_masks, type(None)): + backward_masks = torch.ones( (*self.batch_shape, self.__class__.n_actions - 1), dtype=torch.bool, device=self.__class__.device, ) - self.update_masks() - else: - self.forward_masks = cast(torch.Tensor, forward_masks) - self.backward_masks = cast(torch.Tensor, backward_masks) + # Ensures typecasting is consistent no matter what is submitted to init. + self.forward_masks = cast(torch.Tensor, forward_masks) # TODO: Required? + self.backward_masks = cast(torch.Tensor, backward_masks) # TODO: Required? self.set_default_typing() def clone(self) -> States: @@ -332,10 +330,6 @@ def set_default_typing(self) -> None: self.backward_masks, ) - @abstractmethod - def update_masks(self) -> None: - """Updates the masks, called after each action is taken.""" - def _check_both_forward_backward_masks_exist(self): assert self.forward_masks is not None and self.backward_masks is not None diff --git a/testing/test_environments.py b/testing/test_environments.py index 9dff4d6d..b110baac 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -7,11 +7,6 @@ # Utilities. -def format_actions(a, env): - """Returns a Actions instance from a [batch_size, *action_shape] tensor of actions.""" - return env.Actions(a) - - def format_tensor(list_, discrete=True): """ If discrete, returns a long tensor with a singleton batch dimension from list @@ -89,13 +84,13 @@ def test_HyperGrid_fwd_step_with_preprocessors( failing_actions_list = [2, 0, 1] for actions_list in passing_actions_lists: - actions = format_actions(format_tensor(actions_list), env) - states = env.step(states, actions) + actions = env.actions_from_tensor(format_tensor(actions_list)) + states = env._step(states, actions) # Step 4 fails due an invalid input action. 
- actions = format_actions(format_tensor(failing_actions_list), env) + actions = env.actions_from_tensor(format_tensor(failing_actions_list)) with pytest.raises(NonValidActionsError): - states = env.step(states, actions) + states = env._step(states, actions) expected_rewards = torch.tensor([0.6, 0.1, 0.6]) assert (torch.round(env.reward(states), decimals=7) == expected_rewards).all() @@ -129,14 +124,14 @@ def test_HyperGrid_bwd_step_with_preprocessors( # All passing actions complete sucessfully. for passing_actions_list in passing_actions_lists: - actions = format_actions(format_tensor(passing_actions_list), env) - states = env.backward_step(states, actions) + actions = env.actions_from_tensor(format_tensor(passing_actions_list)) + states = env._backward_step(states, actions) # Fails due to an invalid input action. states = env.reset(batch_shape=(NDIM, ENV_HEIGHT), random=True, seed=SEED) - failing_actions = format_actions(format_tensor(failing_actions_list), env) + failing_actions = env.actions_from_tensor(format_tensor(failing_actions_list)) with pytest.raises(NonValidActionsError): - states = env.backward_step(states, failing_actions) + states = env._backward_step(states, failing_actions) def test_DiscreteEBM_fwd_step(): @@ -156,18 +151,18 @@ def test_DiscreteEBM_fwd_step(): ] # Only next possible move is [4, 4, 4, 4], for actions_list in passing_actions_lists: - actions = format_actions(format_tensor(actions_list), env) - states = env.step(states, actions) + actions = env.actions_from_tensor(format_tensor(actions_list)) + states = env._step(states, actions) # Step 4 fails due an invalid input action (15 is not possible). - actions = format_actions(format_tensor([4, 15, 4, 4]), env) + actions = env.actions_from_tensor(format_tensor([4, 15, 4, 4])) with pytest.raises(RuntimeError): - states = env.step(states, actions) + states = env._step(states, actions) # Step 5 fails due an invalid input action (1 is possible but not in this state). - actions = format_actions(format_tensor([1, 4, 4, 4]), env) + actions = env.actions_from_tensor(format_tensor([1, 4, 4, 4])) with pytest.raises(NonValidActionsError): - states = env.step(states, actions) + states = env._step(states, actions) expected_rewards = torch.tensor([1, 1, 54.5982, 1]) assert (torch.round(env.reward(states), decimals=4) == expected_rewards).all() @@ -188,15 +183,15 @@ def test_DiscreteEBM_bwd_step(): ] # All passing actions complete sucessfully. for passing_actions_list in passing_actions_lists: - actions = format_actions(format_tensor(passing_actions_list), env) - states = env.backward_step(states, actions) + actions = env.actions_from_tensor(format_tensor(passing_actions_list)) + states = env._backward_step(states, actions) # Fails due to an invalid input action. 
failing_actions_list = [0, 0, 0] states = env.reset(batch_shape=BATCH_SIZE, random=True, seed=SEED) - failing_actions = format_actions(format_tensor(failing_actions_list), env) + failing_actions = env.actions_from_tensor(format_tensor(failing_actions_list)) with pytest.raises(NonValidActionsError): - states = env.backward_step(states, failing_actions) + states = env._backward_step(states, failing_actions) @pytest.mark.parametrize("delta", [0.1, 0.5, 1.0]) @@ -214,11 +209,9 @@ def test_box_fwd_step(delta: float): ] for failing_actions_list in failing_actions_lists_at_s0: - actions = format_actions( - format_tensor(failing_actions_list, discrete=False), env - ) + actions = env.actions_from_tensor(format_tensor(failing_actions_list, discrete=False)) with pytest.raises(NonValidActionsError): - states = env.step(states, actions) + states = env._step(states, actions) # Trying the step function starting from 3 instances of s_0 A, B = None, None @@ -244,8 +237,8 @@ def test_box_fwd_step(delta: float): actions_tensor[B - A < 0] = torch.tensor([-float("inf"), -float("inf")]) actions_list = actions_tensor.tolist() - actions = format_actions(format_tensor(actions_list, discrete=False), env) - states = env.step(states, actions) + actions = env.actions_from_tensor(format_tensor(actions_list, discrete=False)) + states = env._step(states, actions) states_tensor = states.tensor # The following evaluate the maximum angles of the possible actions diff --git a/tutorials/ENV.md b/tutorials/ENV.md index 3fb30718..c4797698 100644 --- a/tutorials/ENV.md +++ b/tutorials/ENV.md @@ -1,21 +1,86 @@ # Creating `torchgfn` environments -To define an environment, the user needs to define the tensor `s0` representing the initial state $s_0$, from which the `state_shape` attribute is inferred, and optionally a tensor representing the sink state $s_f$, which is only used for padding incomplete trajectories. If it is not specified, `sf` is set to a tensor of the same shape as `s0` filled with $-\infty$. +To define an environment, the user needs to define the tensor `s0` representing the initial state $s_0$, and optionally a tensor representing the sink state $s_f$, which denotes the end of a trajectory (and can be used for padding). If it is not specified, `sf` is set to a tensor of the same shape as `s0` filled with $-\infty$. -If the environment is discrete, in which case it is an instance of `DiscreteEnv`, the total number of actions should be specified as an attribute. +The user must also define the `action_shape`, which may or may not be of +different dimensionality to the `state_shape`. For example, in many environments +a timestamp exists as part of the state to prevent cycles, and actions cannot +(directly) modify this value. -If the states (as represented in the `States` class) need to be transformed to another format before being processed (by neural networks for example), then the environment should define a `preprocessor` attribute, which should be an instance of the [base preprocessor class](https://github.com/saleml/torchgfn/tree/master/src/gfn/preprocessors.py). If no preprocessor is defined, the states are used as is (actually transformed using the `IdentityPreprocessor`, which transforms the state tensors to `FloatTensor`s). Implementing a specific preprocessor requires defining the `preprocess` function, and the `output_shape` attribute, which is a tuple representing the shape of *one* preprocessed state. +A `dummy_action` and `exit_action` tensor must also be submitted by the user. 
+The `exit_action` is a unique action that brings the state to $s_f$. The +`dummy_action` should be different from the `exit_action` (and not be a valid +trajectory) action - it's used to pad batched action tensors (after the +exit action). This is useful when trajectories will be of different lengths +within the batch. -The user needs to implement the following two abstract functions: -- The method `make_States_class` that creates the corresponding subclass of [`States`](https://github.com/saleml/torchgfn/tree/master/src/gfn/states.py). For discrete environments, the resulting class should be a subclass of [`DiscreteStates`](https://github.com/saleml/torchgfn/tree/master/src/gfn/states.py), that implements the `update_masks` method specifying which actions are available at each state. -- The method `make_Actions_class` that creates a subclass of [`Actions`](https://github.com/saleml/torchgfn/tree/master/src/gfn/actions.py), simply by specifying the required class variables (the shape of an action tensor, the dummy action, and the exit action). This method is implemented by default for all `DiscreteEnv`s. +In addition, a number of methods must be defined by the user: ++ `env.step(self, states, actions)` accepts a batch of states and actions, and + returns a batch of `next_states``. This is used for forward trajectories. ++ `env.backward_step(self, states, actions)` accepts a batch of `next_states` + and actions and returns a batch of `states`. This is used for backward + trajectories. + + These functions do not need to handle masking for discrete + environments, nor checking whether actions are allowed, nor checking + whether a state is the sink state, etc... These checks are handled in + `Env._step` and `Env._backward_step` functions, that are not implemented + by the user. ++ `env.is_action_valid(self, states, actions, backward)`: This function is used + to ensure all actions are valid for both forward and backward trajectores + (these are often different sets of rules) for continuous environments. It + accepts a batch of states and actions, and returning `True` only if all + actions can be taken at the given states. ++ `env.make_random_states_tensor(self, batch_shape)` is an **optional** method + which is consumed by the States class automatically, which is useful if you + want random samples you can evaluate under your reward model or policy. ++ `env.reset(self, ...)` can also **optionally** be overwritten by the user + to support custom logic. For example, for conditional GFlowNets, the + conditioning tensor can be concatenated to $s_0$ automatically here. ++ `env.log_reward(self, final_states)` must be defined, which calculates the + log reward of the terminating states (i.e. state with all $s_f$ as a child in + the DAG). It by default returns the log of `env.reward(self, final_states)`, + which is not implemented. The user can decide to either implement the `reward` + method, or if it is simpler / more numerically stable, to override the + `log_reward` method and leave the `reward` unimplemented. -The logic of the environment is handled by the methods `maskless_step` and `maskless_backward_step`, that need to be implemented, which specify how an action changes a state (going forward and backward). These functions do not need to handle masking for discrete environments, nor checking whether actions are allowed, nor checking whether a state is the sink state, etc... These checks are handled in `Env.step` and `Env.backward_step` functions, that do not need to be implemented. 
Non discrete environments need to implement the `is_action_valid` function however, taking a batch of states and actions, and returning `True` only if all actions can be taken at the given states. -- The `log_reward` function that assigns the logarithm of a nonnegative reward to every terminating state (i.e. state with all $s_f$ as a child in the DAG). If `log_reward` is not implemented, `reward` needs to be. +If the environment is discrete, it is an instance of `DiscreteEnv`, and +therefore total number of actions should be specified as an attribute. The +`action_shape` is assumed to be `(1,)`, as the common use case of a +`DiscreteEnv` would be to sample a single action per step. However, this can be +set to any shape by the user (for example `(1,5)` if the policy is sampling 5 +independent actions per step). +If the states (as represented in the `States` class) need to be transformed to another format before being processed (by neural networks for example), then the environment should define a `preprocessor` attribute, which should be an instance of the [base preprocessor class](https://github.com/gfnorg/torchgfn/tree/master/src/gfn/preprocessors.py). If no preprocessor is defined, the states are used as is (actually transformed using the `IdentityPreprocessor`, which transforms the state tensors to `FloatTensor`s). Implementing a specific preprocessor requires defining the `preprocess` function, and the `output_shape` attribute, which is a tuple representing the shape of *one* preprocessed state. -For `DiscreteEnv`s, the user can define a`get_states_indices` method that assigns a unique integer number to each state, and a `n_states` property that returns an integer representing the number of states (excluding $s_f$) in the environment. The function `get_terminating_states_indices` can also be implemented and serves the purpose of uniquely identifying terminating states of the environment, which is useful for [tabular `GFNModule`s](https://github.com/saleml/torchgfn/tree/master/src/gfn/utils/modules.py). Other properties and functions can be implemented as well, such as the `log_partition` or the `true_dist_pmf` properties. +In addition to the above methods, in the discrete case, you must also define +the following method: + ++ `env.update_masks(self, states)`: in discrete environments, the `States` class + contains state-dependent forward and backward masks, which define allowable + forward and backward actions conditioned on the state. Note that in + calculating these masks, the user can leverage the helper methods + `DiscreteStates.set_nonexit_action_masks`, + `DiscreteStates.set_exit_masks`, and + `DiscreteStates.init_forward_masks`. + +The code automatically implements the following two class factories, which the +majority of users will not need to overwrite. However, the user could override +these factories to imbue new functionality into the `States` and `Actions` that +interact with the environment: +- The method `make_states_class` that creates the corresponding subclass of [`States`](https://github.com/saleml/torchgfn/tree/master/src/gfn/states.py). For discrete environments, the resulting class should be a subclass of [`DiscreteStates`](https://github.com/saleml/torchgfn/tree/master/src/gfn/states.py), that implements the `update_masks` method specifying which actions are available at each state. 
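+  For example, a hedged sketch of overriding this factory (the `MyGridEnv`,
+  `MyGridStates`, and `to_string` names below are purely illustrative, not part
+  of the library):
+
+  ```python
+  from gfn.gym import HyperGrid
+  from gfn.states import DiscreteStates
+
+
+  class MyGridEnv(HyperGrid):
+      def make_states_class(self) -> type[DiscreteStates]:
+          BaseStates = super().make_states_class()
+
+          class MyGridStates(BaseStates):
+              # Environment-specific convenience method added on top of the
+              # default States functionality.
+              def to_string(self) -> str:
+                  return str(self.tensor.tolist())
+
+          return MyGridStates
+  ```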
+- The method `make_actions_class` that creates a subclass of [`Actions`](https://github.com/saleml/torchgfn/tree/master/src/gfn/actions.py), simply by specifying the required class variables (the shape of an action tensor, the dummy action, and the exit action). This method is implemented by default for all `DiscreteEnv`s. + +The logic of the environment is handled by the methods `step` and `backward_step`, that need to be implemented, which specify how an action changes a state (going forward and backward). + +For `DiscreteEnv`s, the user can define a `get_states_indices` method that +assigns a unique integer number to each state, and a `n_states` property that +returns an integer representing the number of states (excluding $s_f$) in the environment. The function `get_terminating_states_indices` can also be +implemented and serves the purpose of uniquely identifying terminating states of +the environment, which is useful for +[tabular `GFNModule`s](https://github.com/saleml/torchgfn/tree/master/src/gfn/utils/modules.py). +Other properties and functions can be implemented as well, such as the +`log_partition` or the `true_dist_pmf` properties. For reference, it might be useful to look at one of the following provided environments: - [HyperGrid](https://github.com/saleml/torchgfn/tree/master/src/gfn/gym/hypergrid.py) is an example of a discrete environment where all states are terminating states. - [DiscreteEBM](https://github.com/saleml/torchgfn/tree/master/src/gfn/gym/discrete_ebm.py) is an example of a discrete environment where all trajectories are of the same length but only some states are terminating. diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index ae592a97..192a5dcb 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -69,7 +69,7 @@ def test_hypergrid(ndim: int, height: int): args = HypergridArgs(ndim=ndim, height=height, n_trajectories=n_trajectories) final_l1_dist = train_hypergrid_main(args) if ndim == 2 and height == 8: - assert np.isclose(final_l1_dist, 8.78e-4, atol=1e-4) + assert np.isclose(final_l1_dist, 8.78e-4, atol=1e-3) elif ndim == 2 and height == 16: assert np.isclose(final_l1_dist, 4.56e-4, atol=1e-4) elif ndim == 4 and height == 8: diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index 33aa1cc8..3574fa2d 100644 --- a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -63,7 +63,7 @@ def main(args): # noqa: C901 # 4. 
Train the gflownet - visited_terminating_states = env.States.from_batch_shape((0,)) + visited_terminating_states = env.states_from_batch_shape((0,)) states_visited = 0 n_iterations = args.n_trajectories // args.batch_size diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index e3301cdd..4d4e3a25 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -217,7 +217,7 @@ def main(args): # noqa: C901 optimizer = torch.optim.Adam(params) - visited_terminating_states = env.States.from_batch_shape((0,)) + visited_terminating_states = env.states_from_batch_shape((0,)) states_visited = 0 n_iterations = args.n_trajectories // args.batch_size diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py index 645a6f06..ccbbf1cf 100644 --- a/tutorials/examples/train_line.py +++ b/tutorials/examples/train_line.py @@ -1,5 +1,3 @@ -from typing import ClassVar, Literal, Tuple - import matplotlib.pyplot as plt import numpy as np import torch @@ -8,112 +6,14 @@ from torchtyping import TensorType as TT from tqdm import trange -from gfn.actions import Actions -from gfn.env import Env from gfn.gflownet import TBGFlowNet # TODO: Extend to SubTBGFlowNet +from gfn.gym.line import Line from gfn.modules import GFNModule from gfn.states import States from gfn.utils import NeuralNet from gfn.utils.common import set_seed -class Line(Env): - """Mixture of Gaussians Line environment.""" - - def __init__( - self, - mus: list, - sigmas: list, - init_value: float, - n_sd: float = 4.5, - n_steps_per_trajectory: int = 5, - device_str: Literal["cpu", "cuda"] = "cpu", - ): - assert len(mus) == len(sigmas) - self.mus = torch.tensor(mus) - self.sigmas = torch.tensor(sigmas) - self.n_sd = n_sd - self.n_steps_per_trajectory = n_steps_per_trajectory - self.mixture = [Normal(m, s) for m, s in zip(self.mus, self.sigmas)] - - self.init_value = init_value # Used in s0. - self.lb = min(self.mus) - self.n_sd * max(self.sigmas) # Convienience only. - self.ub = max(self.mus) + self.n_sd * max(self.sigmas) # Convienience only. - assert self.lb < self.init_value < self.ub - - s0 = torch.tensor([self.init_value, 0.0], device=torch.device(device_str)) - super().__init__(s0=s0) # sf is -inf. - - def make_States_class(self) -> type[States]: - env = self - - class LineStates(States): - state_shape: ClassVar[Tuple[int, ...]] = (2,) - s0 = env.s0 # should be [init x value, 0]. - sf = env.sf # should be [-inf, -inf]. - - return LineStates - - def make_Actions_class(self) -> type[Actions]: - env = self - - class LineActions(Actions): - action_shape: ClassVar[Tuple[int, ...]] = (1,) # Does not include counter! - dummy_action: ClassVar[TT[2]] = torch.tensor( - [float("inf")], device=env.device - ) - exit_action: ClassVar[TT[2]] = torch.tensor( - [-float("inf")], device=env.device - ) - - return LineActions - - def maskless_step( - self, states: States, actions: Actions - ) -> TT["batch_shape", 2, torch.float]: - states.tensor[..., 0] = states.tensor[..., 0] + actions.tensor.squeeze( - -1 - ) # x position. - states.tensor[..., 1] = states.tensor[..., 1] + 1 # Step counter. - return states.tensor - - def maskless_backward_step( - self, states: States, actions: Actions - ) -> TT["batch_shape", 2, torch.float]: - states.tensor[..., 0] = states.tensor[..., 0] - actions.tensor.squeeze( - -1 - ) # x position. - states.tensor[..., 1] = states.tensor[..., 1] - 1 # Step counter. 
- return states.tensor - - def is_action_valid( - self, states: States, actions: Actions, backward: bool = False - ) -> bool: - # Can't take a backward step at the beginning of a trajectory. - if torch.any(states[~actions.is_exit].is_initial_state) and backward: - return False - - return True - - def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]: - s = final_states.tensor[..., 0] - # return torch.logsumexp(torch.stack([m.log_prob(s) for m in self.mixture], 0), 0) - - # if s.nelement() == 0: - # return torch.zeros(final_states.batch_shape) - - log_rewards = torch.empty((len(self.mixture),) + final_states.batch_shape) - for i, m in enumerate(self.mixture): - log_rewards[i] = m.log_prob(s) - - return torch.logsumexp(log_rewards, 0) - - @property - def log_partition(self) -> float: - """Log Partition log of the number of gaussians.""" - return torch.tensor(len(self.mus)).log() - - def render(env, validation_samples=None): """Renders the reward distribution over the 1D env.""" x = np.linspace( @@ -123,7 +23,7 @@ def render(env, validation_samples=None): ) # Get the rewards from our environment. - r = env.States( + r = env.states_from_tensor( torch.tensor(np.stack((x, torch.ones(len(x))), 1)) # Add dummy state counter. ) d = torch.exp(env.log_reward(r)) # Plots the reward, not the log reward. @@ -391,7 +291,7 @@ def train( policy_std_max=policy_std_max, ) pb = StepEstimator(environment, pb_module, backward=True) - gflownet = TBGFlowNet(pf=pf, pb=pb, off_policy=False, init_logZ=0.0) + gflownet = TBGFlowNet(pf=pf, pb=pb, off_policy=True, init_logZ=0.0) gflownet = train( gflownet, diff --git a/tutorials/notebooks/intro_gfn_continuous_line.ipynb b/tutorials/notebooks/intro_gfn_continuous_line.ipynb index 086f5c33..232abda1 100644 --- a/tutorials/notebooks/intro_gfn_continuous_line.ipynb +++ b/tutorials/notebooks/intro_gfn_continuous_line.ipynb @@ -135,7 +135,7 @@ " sf = torch.FloatTensor([float(\"inf\"), float(\"inf\")], ).to(s0.device)\n", " super().__init__(s0=s0, sf=sf) # Overwriting the default sf of -inf.\n", "\n", - " def make_States_class(self) -> type[States]:\n", + " def make_states_class(self) -> type[States]:\n", " env = self\n", "\n", " class LineStates(States):\n", @@ -153,7 +153,7 @@ "\n", " return LineStates\n", "\n", - " def make_Actions_class(self) -> type[Actions]:\n", + " def make_actions_class(self) -> type[Actions]:\n", " env = self\n", "\n", " class LineActions(Actions):\n", diff --git a/tutorials/notebooks/intro_gfn_smiley.ipynb b/tutorials/notebooks/intro_gfn_smiley.ipynb index 3ab820b2..7552ac9a 100644 --- a/tutorials/notebooks/intro_gfn_smiley.ipynb +++ b/tutorials/notebooks/intro_gfn_smiley.ipynb @@ -1624,7 +1624,7 @@ " preprocessor=IdentityPreprocessor(output_dim=state_dim)\n", " )\n", "\n", - " def make_States_class(self) -> type[DiscreteStates]:\n", + " def make_states_class(self) -> type[DiscreteStates]:\n", " \"Creates a States class for this environment\"\n", " env = self\n", "\n", @@ -1728,7 +1728,7 @@ "def train(gflownet, optimizer, env, batch_size = 128, n_episodes = 25_000):\n", " \"\"\"Training loop, keeping track of terminal states over training.\"\"\"\n", " # This stores example terminating states.\n", - " visited_terminating_states = env.States.from_batch_shape((0,))\n", + " visited_terminating_states = env.states_from_batch_shape((0,))\n", " states_visited = 0\n", " losses = []\n", "\n", @@ -1917,7 +1917,7 @@ "source": [ "# Note that here, we have the log edge flows, so we take the sum(exp(log_flows)) to\n", "# calculate the 
partition function estimate.\n", - "s_0 = env.make_States_class()(torch.zeros(6))\n", + "s_0 = env.make_states_class()(torch.zeros(6))\n", "print(\"Partition function estimate Z={:.2f}\".format(\n", " sum(torch.exp(estimator(s_0)[:6])) # logsumexp.\n", " )\n",
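For reference, a numerically stabler but equivalent way to compute the same estimate (assuming the `env` and `estimator` objects defined earlier in that notebook), since exp(logsumexp(x)) equals sum(exp(x)):

import torch

s_0 = env.make_states_class()(torch.zeros(6))
log_flows_out_of_s0 = estimator(s_0)[:6]
# logsumexp in log space, then exponentiate once at the end.
Z_estimate = torch.logsumexp(log_flows_out_of_s0, dim=0).exp()
print("Partition function estimate Z={:.2f}".format(Z_estimate.item()))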