diff --git a/pyproject.toml b/pyproject.toml index a0c1a024..0523821a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,6 @@ einops = ">=0.6.1" numpy = ">=1.21.2" python = "^3.10" torch = ">=1.9.0" -torchtyping = ">=0.1.4" # dev dependencies. black = { version = "24.3", optional = true } diff --git a/src/gfn/actions.py b/src/gfn/actions.py index ec0fbe3a..1d579ae7 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -5,7 +5,6 @@ from typing import ClassVar, Sequence import torch -from torchtyping import TensorType as TT class Actions(ABC): @@ -23,23 +22,22 @@ class Actions(ABC): # The following class variable represents the shape of a single action. action_shape: ClassVar[tuple[int, ...]] # All actions need to have the same shape. # The following class variable is padded to shorter trajectories. - dummy_action: ClassVar[TT["action_shape"]] # Dummy action for the environment. + dummy_action: ClassVar[torch.Tensor] # Dummy action for the environment. # The following class variable corresponds to $s \rightarrow s_f$ transitions. - exit_action: ClassVar[TT["action_shape"]] # Action to exit the environment. + exit_action: ClassVar[torch.Tensor] # Action to exit the environment. - def __init__(self, tensor: TT["batch_shape", "action_shape"]): + def __init__(self, tensor: torch.Tensor): """Initialize actions from a tensor. Args: - tensor: tensor of actions + tensor: tensors representing a batch of actions with shape (*batch_shape, *action_shape). """ - self.tensor = tensor - assert len(tensor.shape) >= len(self.action_shape), ( - f"Actions tensor has shape {tensor.shape}, " - f"but the action shape is {self.action_shape}." - # Ensure the tensor has all action dimensions. + assert tensor.shape[-len(self.action_shape):] == self.action_shape, ( + f"Batched actions tensor has shape {tensor.shape}, but the expected action shape is {self.action_shape}." ) - self.batch_shape = tuple(self.tensor.shape)[: -len(self.action_shape)] + + self.tensor = tensor + self.batch_shape = tuple(self.tensor.shape)[:-len(self.action_shape)] @classmethod def make_dummy_actions(cls, batch_shape: tuple[int]) -> Actions: @@ -134,35 +132,38 @@ def extend_with_dummy_actions(self, required_first_dim: int) -> None: "extend_with_dummy_actions is only implemented for bi-dimensional actions." ) - def compare( - self, other: TT["batch_shape", "action_shape"] - ) -> TT["batch_shape", torch.bool]: + def compare(self, other: torch.Tensor) -> torch.Tensor: """Compares the actions to a tensor of actions. Args: - other: tensor of actions + other: tensor of actions to compare, with shape (*batch_shape, *action_shape). + Returns: boolean tensor of shape batch_shape indicating whether the actions are equal. """ + assert other.shape == self.batch_shape + self.action_shape, ( + f"Expected shape {self.batch_shape + self.action_shape}, got {other.shape}." + ) out = self.tensor == other n_batch_dims = len(self.batch_shape) # Flattens all action dims, which we reduce all over. out = out.flatten(start_dim=n_batch_dims).all(dim=-1) + assert out.dtype == torch.bool and out.shape == self.batch_shape return out @property - def is_dummy(self) -> TT["batch_shape", torch.bool]: - """Returns a boolean tensor indicating whether the actions are dummy actions.""" + def is_dummy(self) -> torch.Tensor: + """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are dummy actions.""" dummy_actions_tensor = self.__class__.dummy_action.repeat( *self.batch_shape, *((1,) * len(self.__class__.action_shape)) ) return self.compare(dummy_actions_tensor) @property - def is_exit(self) -> TT["batch_shape", torch.bool]: - """Returns a boolean tensor indicating whether the actions are exit actions.""" + def is_exit(self) -> torch.Tensor: + """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" exit_actions_tensor = self.__class__.exit_action.repeat( *self.batch_shape, *((1,) * len(self.__class__.action_shape)) ) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index d0545d96..5de2d494 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -10,8 +10,6 @@ import numpy as np import torch -from torch import Tensor -from torchtyping import TensorType as TT from gfn.containers.base import Container from gfn.containers.transitions import Transitions @@ -20,7 +18,7 @@ def is_tensor(t) -> bool: """Checks whether t is a torch.Tensor instance.""" - return isinstance(t, Tensor) + return isinstance(t, torch.Tensor) # TODO: remove env from this class? @@ -40,10 +38,10 @@ class Trajectories(Container): env: The environment in which the trajectories are defined. states: The states of the trajectories. actions: The actions of the trajectories. - when_is_done: The time step at which each trajectory ends. + when_is_done: Tensor of shape (n_trajectories,) indicating the time step at which each trajectory ends. is_backward: Whether the trajectories are backward or forward. - log_rewards: The log_rewards of the trajectories. - log_probs: The log probabilities of the trajectories' actions. + log_rewards: Tensor of shape (n_trajectories,) containing the log rewards of the trajectories. + log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of the trajectories' actions. """ @@ -53,23 +51,24 @@ def __init__( states: States | None = None, conditioning: torch.Tensor | None = None, actions: Actions | None = None, - when_is_done: TT["n_trajectories", torch.long] | None = None, + when_is_done: torch.Tensor | None = None, is_backward: bool = False, - log_rewards: TT["n_trajectories", torch.float] | None = None, - log_probs: TT["max_length", "n_trajectories", torch.float] | None = None, - estimator_outputs: TT["batch_shape", "output_dim", torch.float] | None = None, + log_rewards: torch.Tensor | None = None, + log_probs: torch.Tensor | None = None, + estimator_outputs: torch.Tensor | None = None, ) -> None: """ Args: env: The environment in which the trajectories are defined. states: The states of the trajectories. actions: The actions of the trajectories. - when_is_done: The time step at which each trajectory ends. + when_is_done: Tensor of shape (n_trajectories,) indicating the time step at which each trajectory ends. is_backward: Whether the trajectories are backward or forward. - log_rewards: The log_rewards of the trajectories. - log_probs: The log probabilities of the trajectories' actions. - estimator_outputs: When forward sampling off-policy for an n-step - trajectory, n forward passes will be made on some function approximator, + log_rewards: Tensor of shape (n_trajectories,) containing the log rewards of the trajectories. + log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of the trajectories' actions. + estimator_outputs: Tensor of shape (batch_shape, output_dim). + When forward sampling off-policy for an n-step trajectory, + n forward passes will be made on some function approximator, which may need to be re-used (for example, for evaluating PF). To avoid duplicated effort, the outputs of the forward passes can be stored here. @@ -93,17 +92,25 @@ def __init__( if when_is_done is not None else torch.full(size=(0,), fill_value=-1, dtype=torch.long) ) + assert self.when_is_done.shape == (self.n_trajectories,) and self.when_is_done.dtype == torch.long + self._log_rewards = ( log_rewards if log_rewards is not None else torch.full(size=(0,), fill_value=0, dtype=torch.float) ) - self.log_probs = ( - log_probs - if log_probs is not None - else torch.full(size=(0, 0), fill_value=0, dtype=torch.float) - ) + assert self._log_rewards.shape == (self.n_trajectories,) and self._log_rewards.dtype == torch.float + + if log_probs is not None: + assert log_probs.shape == (self.max_length, self.n_trajectories) and log_probs.dtype == torch.float + else: + log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float) + self.log_probs = log_probs + self.estimator_outputs = estimator_outputs + if self.estimator_outputs is not None: + # assert self.estimator_outputs.shape[:len(self.states.batch_shape)] == self.states.batch_shape TODO: check why fails + assert self.estimator_outputs.dtype == torch.float def __repr__(self) -> str: states = self.states.tensor.transpose(0, 1) @@ -142,7 +149,8 @@ def last_states(self) -> States: return self.states[self.when_is_done - 1, torch.arange(self.n_trajectories)] @property - def log_rewards(self) -> TT["n_trajectories", torch.float] | None: + def log_rewards(self) -> torch.Tensor | None: + """Returns the log rewards of the trajectories as a tensor of shape (n_trajectories,).""" if self._log_rewards is not None: assert self._log_rewards.shape == (self.n_trajectories,) return self._log_rewards @@ -200,13 +208,23 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories: @staticmethod def extend_log_probs( - log_probs: TT["max_length", "n_trajectories", torch.float], new_max_length: int - ) -> TT["max_max_length", "n_trajectories", torch.float]: - """Extend the log_probs matrix by adding 0 until the required length is reached.""" - if log_probs.shape[0] >= new_max_length: + log_probs: torch.Tensor, new_max_length: int + ) -> torch.Tensor: + """Extend the log_probs matrix by adding 0 until the required length is reached. + + Args: + log_probs: The log_probs tensor of shape (max_length, n_trajectories) to extend. + new_max_length: The new length of the log_probs tensor. + + Returns: The extended log_probs tensor of shape (new_max_length, n_trajectories). + + """ + + max_length, n_trajectories = log_probs.shape + if max_length >= new_max_length: return log_probs else: - return torch.cat( + new_log_probs = torch.cat( ( log_probs, torch.full( @@ -221,6 +239,8 @@ def extend_log_probs( ), dim=0, ) + assert new_log_probs.shape == (new_max_length, n_trajectories) + return new_log_probs def extend(self, other: Trajectories) -> None: """Extend the trajectories with another set of trajectories. @@ -267,11 +287,11 @@ def extend(self, other: Trajectories) -> None: # Either set, or append, estimator outputs if they exist in the submitted # trajectory. if self.estimator_outputs is None and isinstance( - other.estimator_outputs, Tensor + other.estimator_outputs, torch.Tensor ): self.estimator_outputs = other.estimator_outputs - elif isinstance(self.estimator_outputs, Tensor) and isinstance( - other.estimator_outputs, Tensor + elif isinstance(self.estimator_outputs, torch.Tensor) and isinstance( + other.estimator_outputs, torch.Tensor ): batch_shape = self.actions.batch_shape n_bs = len(batch_shape) diff --git a/src/gfn/containers/transitions.py b/src/gfn/containers/transitions.py index 88bffecb..8ae945c9 100644 --- a/src/gfn/containers/transitions.py +++ b/src/gfn/containers/transitions.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Sequence import torch -from torchtyping import TensorType as TT if TYPE_CHECKING: from gfn.actions import Actions @@ -36,11 +35,11 @@ def __init__( states: States | None = None, conditioning: torch.Tensor | None = None, actions: Actions | None = None, - is_done: TT["n_transitions", torch.bool] | None = None, + is_done: torch.Tensor | None = None, next_states: States | None = None, is_backward: bool = False, - log_rewards: TT["n_transitions", torch.float] | None = None, - log_probs: TT["n_transitions", torch.float] | None = None, + log_rewards: torch.Tensor | None = None, + log_probs: torch.Tensor | None = None, ): """Instantiates a container for transitions. @@ -52,14 +51,14 @@ def __init__( states: States object with uni-dimensional `batch_shape`, representing the parents of the transitions. actions: Actions chosen at the parents of each transitions. - is_done: Whether the action is the exit action. + is_done: Tensor of shape (n_transitions,) indicating whether the action is the exit action. next_states: States object with uni-dimensional `batch_shape`, representing the children of the transitions. is_backward: Whether the transitions are backward transitions (i.e. `next_states` is the parent of states). - log_rewards: The log-rewards of the transitions (using a default value like + log_rewards: Tensor of shape (n_transitions,) containing the log-rewards of the transitions (using a default value like `-float('inf')` for non-terminating transitions). - log_probs: The log-probabilities of the actions. + log_probs: Tensor of shape (n_transitions,) containing the log-probabilities of the actions. Raises: AssertionError: If states and next_states do not have matching @@ -78,11 +77,15 @@ def __init__( self.actions = ( actions if actions is not None else env.actions_from_batch_shape((0,)) ) + assert self.actions.batch_shape == self.states.batch_shape + self.is_done = ( is_done if is_done is not None else torch.full(size=(0,), fill_value=False, dtype=torch.bool) ) + assert self.is_done.shape == (self.n_transitions,) and self.is_done.dtype == torch.bool + self.next_states = ( next_states if next_states is not None @@ -93,7 +96,9 @@ def __init__( and self.states.batch_shape == self.next_states.batch_shape ) self._log_rewards = log_rewards if log_rewards is not None else torch.zeros(0) + assert self._log_rewards.shape == (self.n_transitions,) and self._log_rewards.dtype == torch.float self.log_probs = log_probs if log_probs is not None else torch.zeros(0) + assert self.log_probs.shape == (self.n_transitions,) and self.log_probs.dtype == torch.float @property def n_transitions(self) -> int: @@ -124,7 +129,8 @@ def last_states(self) -> States: return self.states[self.is_done] @property - def log_rewards(self) -> TT["n_transitions", torch.float] | None: + def log_rewards(self) -> torch.Tensor | None: + """Compute the tensor of shape (n_transitions,) containing the log rewards for the transitions.""" if self._log_rewards is not None: return self._log_rewards if self.is_backward: @@ -143,13 +149,17 @@ def log_rewards(self) -> TT["n_transitions", torch.float] | None: return log_rewards @property - def all_log_rewards(self) -> TT["n_transitions", 2, torch.float]: + def all_log_rewards(self) -> torch.Tensor: """Calculate all log rewards for the transitions. This is applicable to environments where all states are terminating. This function evaluates the rewards for all transitions that do not end in the sink state. This is useful for the Modified Detailed Balance loss. + Returns: + log_rewards: Tensor of shape (n_transitions, 2) containing the log rewards + for the transitions. + Raises: NotImplementedError: when used for backward transitions. """ @@ -176,6 +186,8 @@ def all_log_rewards(self) -> TT["n_transitions", 2, torch.float]: log_rewards[~is_sink_state, 1] = torch.log( self.env.reward(self.next_states[~is_sink_state]) ) + + assert log_rewards.shape == (self.n_transitions, 2) and log_rewards.dtype == torch.float return log_rewards def __getitem__(self, index: int | Sequence[int]) -> Transitions: diff --git a/src/gfn/env.py b/src/gfn/env.py index c6097243..ae7f2923 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -2,8 +2,6 @@ from typing import Optional, Tuple, Union import torch -from torch import Tensor -from torchtyping import TensorType as TT from gfn.actions import Actions from gfn.preprocessors import IdentityPreprocessor, Preprocessor @@ -24,26 +22,26 @@ class Env(ABC): def __init__( self, - s0: TT["state_shape", torch.float], + s0: torch.Tensor, state_shape: Tuple, action_shape: Tuple, - dummy_action: Tensor, - exit_action: Tensor, - sf: Optional[TT["state_shape", torch.float]] = None, + dummy_action: torch.Tensor, + exit_action: torch.Tensor, + sf: Optional[torch.Tensor] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, ): """Initializes an environment. Args: - s0: Representation of the initial state. All individual states would be of - the same shape. - state_shape: - action_shape: - dummy_action: - exit_action: - sf: Representation of the final state. Only used for a human - readable representation of the states or trajectories. + s0: Tensor of shape "state_shape" representing the initial state. + All individual states would be of the same shape. + state_shape: Tuple representing the shape of the states. + action_shape: Tuple representing the shape of the actions. + dummy_action: Tensor of shape "action_shape" representing a dummy action. + exit_action: Tensor of shape "action_shape" representing the exit action. + sf: Tensor of shape "state_shape" representing the final state. + Only used for a human readable representation of the states or trajectories. device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is inferred from s0. preprocessor: a Preprocessor object that converts raw states to a tensor @@ -53,8 +51,10 @@ def __init__( self.device = get_device(device_str, default_device=s0.device) self.s0 = s0.to(self.device) + assert s0.shape == state_shape if sf is None: sf = torch.full(s0.shape, -float("inf")).to(self.device) + assert sf.shape == state_shape self.sf = sf self.state_shape = state_shape self.action_shape = action_shape @@ -74,37 +74,79 @@ def __init__( self.preprocessor = preprocessor self.is_discrete = False - def states_from_tensor(self, tensor: Tensor): - """Wraps the supplied Tensor in a States instance.""" + def states_from_tensor(self, tensor: torch.Tensor): + """Wraps the supplied Tensor in a States instance. + + Args: + tensor: The tensor of shape "state_shape" representing the states. + + Returns: + States: An instance of States. + """ return self.States(tensor) def states_from_batch_shape(self, batch_shape: Tuple): - """Returns a batch of s0 states with a given batch_shape.""" + """Returns a batch of s0 states with a given batch_shape. + + Args: + batch_shape: Tuple representing the shape of the batch of states. + + Returns: + States: A batch of initial states. + """ return self.States.from_batch_shape(batch_shape) - def actions_from_tensor(self, tensor: Tensor): - """Wraps the supplied Tensor an an Actions instance.""" + def actions_from_tensor(self, tensor: torch.Tensor): + """Wraps the supplied Tensor an an Actions instance. + + Args: + tensor: The tensor of shape "action_shape" representing the actions. + + Returns: + Actions: An instance of Actions. + """ return self.Actions(tensor) def actions_from_batch_shape(self, batch_shape: Tuple): - """Returns a batch of dummy actions with the supplied batch_shape.""" + """Returns a batch of dummy actions with the supplied batch_shape. + + Args: + batch_shape: Tuple representing the shape of the batch of actions. + + Returns: + Actions: A batch of dummy actions. + """ return self.Actions.make_dummy_actions(batch_shape) # To be implemented by the User. @abstractmethod def step( self, states: States, actions: Actions - ) -> TT["batch_shape", "state_shape", torch.float]: + ) -> torch.Tensor: """Function that takes a batch of states and actions and returns a batch of next states. Does not need to check whether the actions are valid or the states are sink states. + + Args: + states: A batch of states. + actions: A batch of actions. + + Returns: + torch.Tensor: A batch of next states. """ @abstractmethod def backward_step( # TODO: rename to backward_step, other method becomes _backward_step. self, states: States, actions: Actions - ) -> TT["batch_shape", "state_shape", torch.float]: + ) -> torch.Tensor: """Function that takes a batch of states and actions and returns a batch of previous states. Does not need to check whether the actions are valid or the states are sink states. + + Args: + states: A batch of states. + actions: A batch of actions. + + Returns: + torch.Tensor: A batch of previous states. """ @abstractmethod @@ -116,7 +158,7 @@ def is_action_valid( ) -> bool: """Returns True if the actions are valid in the given states.""" - def make_random_states_tensor(self, batch_shape: Tuple) -> Tensor: + def make_random_states_tensor(self, batch_shape: Tuple) -> torch.Tensor: """Optional method inherited by all States instances to emit a random tensor.""" raise NotImplementedError @@ -201,8 +243,11 @@ def _step( Function that takes a batch of states and actions and returns a batch of next states and a boolean tensor indicating sink states in the new batch. """ + assert states.batch_shape == actions.batch_shape new_states = states.clone() # TODO: Ensure this is efficient! - valid_states_idx: TT["batch_shape", torch.bool] = ~states.is_sink_state + valid_states_idx: torch.Tensor = ~states.is_sink_state + assert valid_states_idx.shape == states.batch_shape + assert valid_states_idx.dtype == torch.bool valid_actions = actions[valid_states_idx] valid_states = states[valid_states_idx] @@ -214,6 +259,7 @@ def _step( new_sink_states_idx = actions.is_exit new_states.tensor[new_sink_states_idx] = self.sf new_sink_states_idx = ~valid_states_idx | new_sink_states_idx + assert new_sink_states_idx.shape == states.batch_shape not_done_states = new_states[~new_sink_states_idx] not_done_actions = actions[~new_sink_states_idx] @@ -238,8 +284,11 @@ def _backward_step( This function takes a batch of states and actions and returns a batch of next states and a boolean tensor indicating initial states in the new batch. """ + assert states.batch_shape == actions.batch_shape new_states = states.clone() # TODO: Ensure this is efficient! - valid_states_idx: TT["batch_shape", torch.bool] = ~new_states.is_initial_state + valid_states_idx: torch.Tensor = ~new_states.is_initial_state + assert valid_states_idx.shape == states.batch_shape + assert valid_states_idx.dtype == torch.bool valid_actions = actions[valid_states_idx] valid_states = states[valid_states_idx] @@ -257,15 +306,28 @@ def _backward_step( return new_states - def reward(self, final_states: States) -> TT["batch_shape", torch.float]: + def reward(self, final_states: States) -> torch.Tensor: """The environment's reward given a state. - This or log_reward must be implemented. + + Args: + final_states: A batch of final states. + + Returns: + torch.Tensor: Tensor of shape "batch_shape" containing the rewards. """ raise NotImplementedError("Reward function is not implemented.") - def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]: - """Calculates the log reward.""" + def log_reward(self, final_states: States) -> torch.Tensor: + """Calculates the log reward. + This or reward must be implemented. + + Args: + final_states: A batch of final states. + + Returns: + torch.Tensor: Tensor of shape "batch_shape" containing the log rewards. + """ return torch.log(self.reward(final_states)) @property @@ -288,12 +350,12 @@ class DiscreteEnv(Env, ABC): def __init__( self, n_actions: int, - s0: TT["state_shape", torch.float], + s0: torch.Tensor, state_shape: Tuple, action_shape: Tuple = (1,), - dummy_action: Optional[TT["action_shape", torch.long]] = None, - exit_action: Optional[TT["action_shape", torch.long]] = None, - sf: Optional[TT["state_shape", torch.float]] = None, + dummy_action: Optional[torch.Tensor] = None, + exit_action: Optional[torch.Tensor] = None, + sf: Optional[torch.Tensor] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, ): @@ -301,27 +363,31 @@ def __init__( Args: n_actions: The number of actions in the environment. - s0: The initial state tensor (shared among all trajectories). - state_shape: - action_shape: ? - dummy_action: The value of the dummy (padding) action. - exit_action: The value of the exit action. - sf: The final state tensor (shared among all trajectories). + s0: Tensor of shape "state_shape" representing the initial state (shared among all trajectories). + state_shape: Tuple representing the shape of the states. + action_shape: Tuple representing the shape of the actions. + dummy_action: Optional tensor of shape "action_shape" representing the dummy (padding) action. + exit_action: Optional tensor of shape "action_shape" representing the exit action. + sf: Tensor of shape "state_shape" representing the final state tensor (shared among all trajectories). device_str: String representation of a torch.device. preprocessor: An optional preprocessor for intermediate states. """ device = get_device(device_str, default_device=s0.device) # The default dummy action is -1. - if isinstance(dummy_action, type(None)): + if dummy_action is None: dummy_action = torch.tensor([-1], device=device) # The default exit action index is the final element of the action space. - if isinstance(exit_action, type(None)): + if exit_action is None: exit_action = torch.tensor([n_actions - 1], device=device) - self.n_actions = n_actions # Before init, for compatibility with States. + assert s0.shape == state_shape + assert dummy_action.shape == action_shape + assert exit_action.shape == action_shape + + self.n_actions = n_actions # Before init, for compatibility with States. super().__init__( s0, state_shape, @@ -335,8 +401,15 @@ def __init__( self.is_discrete = True # After init, else it will be overwritten. - def states_from_tensor(self, tensor: Tensor): - """Wraps the supplied Tensor in a States instance & updates masks.""" + def states_from_tensor(self, tensor: torch.Tensor): + """Wraps the supplied Tensor in a States instance & updates masks. + + Args: + tensor: The tensor of shape "state_shape" representing the states. + + Returns: + States: An instance of States. + """ states_instance = self.make_states_class()(tensor) self.update_masks(states_instance) return states_instance @@ -418,14 +491,30 @@ def _step(self, states: DiscreteStates, actions: Actions) -> States: def get_states_indices( self, states: DiscreteStates - ) -> TT["batch_shape", torch.long]: + ) -> torch.Tensor: + """Returns the indices of the states in the environment. + + Args: + states: The batch of states. + + Returns: + torch.Tensor: Tensor of shape "batch_shape" containing the indices of the states. + """ return NotImplementedError( "The environment does not support enumeration of states" ) def get_terminating_states_indices( self, states: DiscreteStates - ) -> TT["batch_shape", torch.long]: + ) -> torch.Tensor: + """Returns the indices of the terminating states in the environment. + + Args: + states: The batch of states. + + Returns: + torch.Tensor: Tensor of shape "batch_shape" containing the indices of the terminating states. + """ return NotImplementedError( "The environment does not support enumeration of states" ) @@ -443,8 +532,8 @@ def n_terminating_states(self) -> int: ) @property - def true_dist_pmf(self) -> TT["n_states", torch.float]: - "Returns a one-dimensional tensor representing the true distribution." + def true_dist_pmf(self) -> torch.Tensor: + "Returns a tensor of shape (n_states,) representing the true distribution." raise NotImplementedError( "The environment does not support enumeration of states" ) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index b7865a88..8e3fd4b5 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn -from torchtyping import TensorType as TT from gfn.containers import Trajectories from gfn.containers.base import Container @@ -128,10 +127,7 @@ def get_pfs_and_pbs( trajectories: Trajectories, fill_value: float = 0.0, recalculate_all_logprobs: bool = False, - ) -> Tuple[ - TT["max_length", "n_trajectories", torch.float], - TT["max_length", "n_trajectories", torch.float], - ]: + ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Evaluates logprobs for each transition in each trajectory in the batch. More specifically it evaluates $\log P_F (s' \mid s)$ and $\log P_B(s \mid s')$ @@ -246,18 +242,26 @@ def get_pfs_and_pbs( log_pb_trajectories_slice[~valid_actions.is_exit] = valid_log_pb_actions log_pb_trajectories[~trajectories.actions.is_dummy] = log_pb_trajectories_slice + assert log_pf_trajectories.shape == (trajectories.max_length, trajectories.n_trajectories) + assert log_pb_trajectories.shape == (trajectories.max_length, trajectories.n_trajectories) return log_pf_trajectories, log_pb_trajectories def get_trajectories_scores( self, trajectories: Trajectories, recalculate_all_logprobs: bool = False, - ) -> Tuple[ - TT["n_trajectories", torch.float], - TT["n_trajectories", torch.float], - TT["n_trajectories", torch.float], - ]: - """Given a batch of trajectories, calculate forward & backward policy scores.""" + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Given a batch of trajectories, calculate forward & backward policy scores. + + Args: + trajectories: Trajectories to evaluate. + recalculate_all_logprobs: Whether to re-evaluate all logprobs. + + Returns: A tuple of float tensors of shape (n_trajectories,) + containing the total log_pf, total log_pb, and the total + log-likelihood of the trajectories. + + """ log_pf_trajectories, log_pb_trajectories = self.get_pfs_and_pbs( trajectories, recalculate_all_logprobs=recalculate_all_logprobs ) @@ -275,6 +279,9 @@ def get_trajectories_scores( torch.isinf(total_log_pb_trajectories) ): raise ValueError("Infinite logprobs found") + + assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,) + assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,) return ( total_log_pf_trajectories, total_log_pb_trajectories, diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index 2060f7bf..f8ca618e 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -2,7 +2,6 @@ from typing import Tuple import torch -from torchtyping import TensorType as TT from gfn.containers import Trajectories, Transitions from gfn.env import Env @@ -81,11 +80,7 @@ def logF_parameters(self): def get_scores( self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = False - ) -> Tuple[ - TT["n_transitions", float], - TT["n_transitions", float], - TT["n_transitions", float], - ]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Given a batch of transitions, calculate the scores. Args: @@ -96,6 +91,9 @@ def get_scores( - If transitions have log_probs attribute, use them - this is usually for on-policy learning - Else, re-evaluate the log_probs using the current self.pf - this is usually for off-policy learning with replay buffer + + Returns: A tuple of three tensors of shapes (n_transitions,), representing the + log probabilities of the actions, the log probabilities of the backward actions, and th scores. Raises: ValueError: when supplied with backward transitions. @@ -194,9 +192,12 @@ def get_scores( scores = preds - targets - return (valid_log_pf_actions, log_pb_actions, scores) + assert valid_log_pf_actions.shape == (transitions.n_transitions,) + assert log_pb_actions.shape == (transitions.n_transitions,) + assert scores.shape == (transitions.n_transitions,) + return valid_log_pf_actions, log_pb_actions, scores - def loss(self, env: Env, transitions: Transitions) -> TT[0, float]: + def loss(self, env: Env, transitions: Transitions) -> torch.Tensor: """Detailed balance loss. The detailed balance loss is described in section @@ -223,7 +224,7 @@ class ModifiedDBGFlowNet(PFBasedGFlowNet[Transitions]): def get_scores( self, transitions: Transitions, recalculate_all_logprobs: bool = False - ) -> TT["n_trajectories", torch.float]: + ) -> torch.Tensor: """DAG-GFN-style detailed balance, when all states are connected to the sink. Unless recalculate_all_logprobs=True, in which case we re-evaluate the logprobs of the transitions with @@ -304,7 +305,7 @@ def get_scores( return scores - def loss(self, env: Env, transitions: Transitions) -> TT[0, float]: + def loss(self, env: Env, transitions: Transitions) -> torch.Tensor: """Calculates the modified detailed balance loss.""" scores = self.get_scores(transitions) return torch.mean(scores**2) diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index d9a7c97b..c093ee97 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -1,7 +1,6 @@ from typing import Tuple, Any, Union import torch -from torchtyping import TensorType as TT from gfn.containers import Trajectories from gfn.env import Env @@ -71,7 +70,7 @@ def flow_matching_loss( env: Env, states: DiscreteStates, conditioning: torch.Tensor | None, - ) -> TT["n_trajectories", torch.float]: + ) -> torch.Tensor: """Computes the FM for the provided states. The Flow Matching loss is defined as the log-sum incoming flows minus log-sum @@ -160,7 +159,7 @@ def reward_matching_loss( env: Env, terminating_states: DiscreteStates, conditioning: torch.Tensor, - ) -> TT[0, float]: + ) -> torch.Tensor: """Calculates the reward matching loss from the terminating states.""" del env # Unused assert terminating_states.log_rewards is not None @@ -184,7 +183,7 @@ def loss( Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor], Tuple[DiscreteStates, DiscreteStates, None, None], ], - ) -> TT[0, float]: + ) -> torch.Tensor: """Given a batch of non-terminal and terminal states, compute a loss. Unlike the GFlowNets Foundations paper, we allow more flexibility by passing a diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index 5cbb8b54..b8c41688 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -2,7 +2,6 @@ from typing import List, Literal, Tuple import torch -from torchtyping import TensorType as TT from gfn.containers import Trajectories from gfn.env import Env @@ -14,13 +13,13 @@ ) -ContributionsTensor = TT["max_len * (1 + max_len) / 2", "n_trajectories"] -CumulativeLogProbsTensor = TT["max_length + 1", "n_trajectories"] -LogStateFlowsTensor = TT["max_length", "n_trajectories"] -LogTrajectoriesTensor = TT["max_length", "n_trajectories", torch.float] -MaskTensor = TT["max_length", "n_trajectories"] -PredictionsTensor = TT["max_length + 1 - i", "n_trajectories"] -TargetsTensor = TT["max_length + 1 - i", "n_trajectories"] +ContributionsTensor = torch.Tensor # shape: [max_len * (1 + max_len) / 2, n_trajectories] +CumulativeLogProbsTensor = torch.Tensor # shape: [max_length + 1, n_trajectories] +LogStateFlowsTensor = torch.Tensor # shape: [max_length, n_trajectories] +LogTrajectoriesTensor = torch.Tensor # shape: [max_length, n_trajectories] +MaskTensor = torch.Tensor # shape: [max_length, n_trajectories] +PredictionsTensor = torch.Tensor # shape: [max_length + 1 - i, n_trajectories] +TargetsTensor = torch.Tensor # shape: [max_length + 1 - i, n_trajectories] class SubTBGFlowNet(TrajectoryBasedGFlowNet): @@ -116,7 +115,8 @@ def cumulative_logprobs( trajectories: a batch of trajectories. log_p_trajectories: log probabilities of each transition in each trajectory. - Returns: cumulative sum of log probabilities of each trajectory. + Returns: Tensor of shape (max_length + 1, n_trajectories), containing the + cumulative sum of log probabilities of each trajectory. """ return torch.cat( ( @@ -136,6 +136,13 @@ def calculate_preds( ) -> PredictionsTensor: """ Calculate the predictions tensor for the current sub-trajectory length. + + Args: + log_pf_trajectories_cum: Tensor of shape (max_length + 1, n_trajectories) containing the cumulative log probabilities of the forward actions. + log_state_flows: Tensor of shape (max_length, n_trajectories) containing the log state flows. + i: The sub-trajectory length. + + Returns: The predictions tensor of shape (max_length + 1 - i, n_trajectories). """ current_log_state_flows = ( log_state_flows if i == 1 else log_state_flows[: -(i - 1)] @@ -162,6 +169,18 @@ def calculate_targets( ) -> TargetsTensor: """ Calculate the targets tensor for the current sub-trajectory length. + + Args: + trajectories: The trajectories data. + preds: The predictions tensor of shape (max_length + 1 - i, n_trajectories). + log_pb_trajectories_cum: Tensor of shape (max_length + 1, n_trajectories) containing the cumulative log probabilities of the backward actions. + log_state_flows: Tensor of shape (max_length, n_trajectories) containing the log state flows. + is_terminal_mask: A mask tensor of shape (max_length, n_trajectories) representing terminal states. + sink_states_mask: A mask tensor of shape (max_length, n_trajectories) representing sink states. + full_mask: A mask tensor of shape (max_length, n_trajectories) representing full states. + i: The sub-trajectory length. + + Returns: The targets tensor of shape (max_length + 1 - i, n_trajectories). """ targets = torch.full_like(preds, fill_value=-float("inf")) assert trajectories.log_rewards is not None @@ -199,12 +218,12 @@ def calculate_log_state_flows( Calculate log state flows and masks for sink and terminal states. Args: - trajectories: The trajectories data. env: The environment object. + trajectories: The trajectories data. + log_pf_trajectories: Tensor of shape (max_length, n_trajectories) containing the log forward probabilities of the trajectories. Returns: - log_state_flows: Log state flows. - full_mask: A boolean tensor representing full states. + log_state_flows: Tensor of shape (max_length, n_trajectories) containing the log state flows. """ states = trajectories.states log_state_flows = torch.full_like(log_pf_trajectories, fill_value=-float("inf")) @@ -239,6 +258,12 @@ def calculate_masks( ) -> Tuple[MaskTensor, MaskTensor, MaskTensor]: """ Calculate masks for sink and terminal states. + + Args: + log_state_flows: Tensor of shape (max_length, n_trajectories) containing the log state flows. + trajectories: The trajectories data. + + Returns: a tuple of three mask tensors of shape (max_length, n_trajectories). """ sink_states_mask = log_state_flows == -float("inf") is_terminal_mask = trajectories.actions.is_exit @@ -248,7 +273,7 @@ def calculate_masks( def get_scores( self, env: Env, trajectories: Trajectories - ) -> Tuple[List[TT[0, float]], List[TT[0, float]]]: + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: """Scores all submitted trajectories. Returns: @@ -315,15 +340,21 @@ def get_scores( flattening_masks.append(flattening_mask) scores.append(preds - targets) - return (scores, flattening_masks) + return scores, flattening_masks def get_equal_within_contributions( self, trajectories: Trajectories, - all_scores: TT, + all_scores: torch.Tensor, ) -> ContributionsTensor: """ Calculates contributions for the 'equal_within' weighting method. + + Args: + trajectories: The trajectories data. + all_scores: The scores tensor. + + Returns: The contributions tensor of shape (max_len * (1 + max_len) / 2, n_trajectories). """ del all_scores is_done = trajectories.when_is_done @@ -344,10 +375,16 @@ def get_equal_within_contributions( def get_equal_contributions( self, trajectories: Trajectories, - all_scores: TT, + all_scores: torch.Tensor, ) -> ContributionsTensor: """ Calculates contributions for the 'equal' weighting method. + + Args: + trajectories: The trajectories data. + all_scores: The scores tensor. + + Returns: The contributions tensor of shape (max_len * (1 + max_len) / 2, n_trajectories). """ is_done = trajectories.when_is_done max_len = trajectories.max_length @@ -357,10 +394,16 @@ def get_equal_contributions( return contributions def get_tb_contributions( - self, trajectories: Trajectories, all_scores: TT + self, trajectories: Trajectories, all_scores: torch.Tensor ) -> ContributionsTensor: """ Calculates contributions for the 'TB' weighting method. + + Args: + trajectories: The trajectories data. + all_scores: The scores tensor. + + Returns: The contributions tensor of shape (max_len * (1 + max_len) / 2, n_trajectories). """ max_len = trajectories.max_length is_done = trajectories.when_is_done @@ -376,10 +419,16 @@ def get_tb_contributions( def get_modified_db_contributions( self, trajectories: Trajectories, - all_scores: TT, + all_scores: torch.Tensor, ) -> ContributionsTensor: """ Calculates contributions for the 'ModifiedDB' weighting method. + + Args: + trajectories: The trajectories data. + all_scores: The scores tensor. + + Returns: The contributions tensor of shape (max_len * (1 + max_len) / 2, n_trajectories). """ del all_scores is_done = trajectories.when_is_done @@ -404,10 +453,16 @@ def get_modified_db_contributions( def get_geometric_within_contributions( self, trajectories: Trajectories, - all_scores: TT, + all_scores: torch.Tensor, ) -> ContributionsTensor: """ Calculates contributions for the 'geometric_within' weighting method. + + Args: + trajectories: The trajectories data. + all_scores: The scores tensor. + + Returns: The contributions tensor of shape (max_len * (1 + max_len) / 2, n_trajectories). """ del all_scores L = self.lamda @@ -437,7 +492,7 @@ def get_geometric_within_contributions( return contributions - def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]: + def loss(self, env: Env, trajectories: Trajectories) -> torch.Tensor: # Get all scores and masks from the trajectories. scores, flattening_masks = self.get_scores(env, trajectories) flattening_mask = torch.cat(flattening_masks) diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index 94fff80e..b9ee4ae9 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn -from torchtyping import TensorType as TT from gfn.containers import Trajectories from gfn.env import Env @@ -52,7 +51,7 @@ def loss( env: Env, trajectories: Trajectories, recalculate_all_logprobs: bool = False, - ) -> TT[0, float]: + ) -> torch.Tensor: """Trajectory balance loss. The trajectory balance loss is described in 2.3 of @@ -105,7 +104,7 @@ def loss( env: Env, trajectories: Trajectories, recalculate_all_logprobs: bool = False, - ) -> TT[0, float]: + ) -> torch.Tensor: """Log Partition Variance loss. This method is described in section 3.2 of diff --git a/src/gfn/gym/box.py b/src/gfn/gym/box.py index 22ed18a7..5435f629 100644 --- a/src/gfn/gym/box.py +++ b/src/gfn/gym/box.py @@ -2,7 +2,6 @@ from typing import Literal, Tuple import torch -from torchtyping import TensorType as TT from gfn.actions import Actions from gfn.env import Env @@ -46,21 +45,43 @@ def __init__( def make_random_states_tensor( self, batch_shape: Tuple[int, ...] - ) -> TT["batch_shape", 2, torch.float]: + ) -> torch.Tensor: + """Generates random states tensor of shape (*batch_shape, 2).""" return torch.rand(batch_shape + (2,), device=self.device) def step( self, states: States, actions: Actions - ) -> TT["batch_shape", 2, torch.float]: + ) -> torch.Tensor: + """Step function for the Box environment. + + Args: + states: States object representing the current states. + actions: Actions object representing the actions to be taken. + + Returns the next states as tensor of shape (*batch_shape, 2). + """ return states.tensor + actions.tensor def backward_step( self, states: States, actions: Actions - ) -> TT["batch_shape", 2, torch.float]: + ) -> torch.Tensor: + """Backward step function for the Box environment. + + Args: + states: States object representing the current states. + actions: Actions object representing the actions to be taken. + + Returns the previous states as tensor of shape (*batch_shape, 2). + """ return states.tensor - actions.tensor @staticmethod - def norm(x: TT["batch_shape", 2, torch.float]) -> torch.Tensor: + def norm(x: torch.Tensor) -> torch.Tensor: + """Computes the L2 norm of the input tensor along the last dimension. + + Args: + x: Input tensor of shape (*batch_shape, 2). + Returns: normalized tensor of shape `batch_shape`.""" return torch.norm(x, dim=-1) def is_action_valid( @@ -103,14 +124,21 @@ def is_action_valid( return True - def reward(self, final_states: States) -> TT["batch_shape", torch.float]: - """Reward is distance from the goal point.""" + def reward(self, final_states: States) -> torch.Tensor: + """Reward is distance from the goal point. + + Args: + final_states: States object representing the final states. + + Returns the reward tensor of shape `batch_shape`. + """ R0, R1, R2 = (self.R0, self.R1, self.R2) ax = abs(final_states.tensor - 0.5) reward = ( R0 + (0.25 < ax).prod(-1) * R1 + ((0.3 < ax) * (ax < 0.4)).prod(-1) * R2 ) - + + assert reward.shape == final_states.batch_shape return reward @property diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py index 644d6cbd..8743db65 100644 --- a/src/gfn/gym/discrete_ebm.py +++ b/src/gfn/gym/discrete_ebm.py @@ -3,8 +3,6 @@ import torch import torch.nn as nn -from torch import Tensor -from torchtyping import TensorType as TT from gfn.actions import Actions from gfn.env import DiscreteEnv @@ -16,24 +14,42 @@ class EnergyFunction(nn.Module, ABC): """Base class for energy functions""" @abstractmethod - def forward( - self, states: TT["batch_shape", "state_shape", torch.float] - ) -> TT["batch_shape"]: + def forward(self, states: torch.Tensor) -> torch.Tensor: + """Forward pass of the energy function + + Args: + states: tensor of states of shape (*batch_shape, *state_shape) + + Returns tensor of energies of shape (*batch_shape) + """ pass class IsingModel(EnergyFunction): """Ising model energy function""" - def __init__(self, J: TT["state_shape", "state_shape", torch.float]): + def __init__(self, J: torch.Tensor): + """Ising model energy function + + Args: + J: interaction matrix of shape (state_shape, state_shape) + """ super().__init__() self.J = J - self.linear = nn.Linear(J.shape[0], 1, bias=False) + self._state_shape, _ = J.shape + assert J.shape == (self._state_shape, self._state_shape) + self.linear = nn.Linear(self._state_shape, 1, bias=False) self.linear.weight.data = J - def forward( - self, states: TT["batch_shape", "state_shape", torch.float] - ) -> TT["batch_shape"]: + def forward(self, states: torch.Tensor) -> torch.Tensor: + """Forward pass of the ising model. + + Args: + states: tensor of states of shape (*batch_shape, *state_shape) + + Returns tensor of energies of shape (*batch_shape) + """ + assert states.shape[-1] == self._state_shape states = states.float() tmp = self.linear(states) return -(states * tmp).sum(-1) @@ -99,14 +115,14 @@ def __init__( ) def update_masks(self, states: type[States]) -> None: - states.set_default_typing() states.forward_masks[..., : self.ndim] = states.tensor == -1 states.forward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == -1 states.forward_masks[..., -1] = torch.all(states.tensor != -1, dim=-1) states.backward_masks[..., : self.ndim] = states.tensor == 0 states.backward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == 1 - def make_random_states_tensor(self, batch_shape: Tuple) -> Tensor: + def make_random_states_tensor(self, batch_shape: Tuple) -> torch.Tensor: + """Generates random states tensor of shape (*batch_shape, ndim).""" return torch.randint( -1, 2, @@ -115,12 +131,25 @@ def make_random_states_tensor(self, batch_shape: Tuple) -> Tensor: device=self.device, ) - def is_exit_actions(self, actions: TT["batch_shape"]) -> TT["batch_shape"]: + def is_exit_actions(self, actions: torch.Tensor) -> torch.Tensor: + """Determines if the actions are exit actions. + + Args: + actions: tensor of actions of shape (*batch_shape, *action_shape) + + Returns tensor of booleans of shape (*batch_shape) + """ return actions == self.n_actions - 1 - def step( - self, states: States, actions: Actions - ) -> TT["batch_shape", "state_shape", torch.float]: + def step(self, states: States, actions: Actions) -> torch.Tensor: + """Performs a step. + + Args: + states: States object representing the current states. + actions: Actions object representing the actions to be taken. + + Returns the next states as tensor of shape (*batch_shape, ndim). + """ # First, we select that actions that replace a -1 with a 0. # Remove singleton dimension for broadcasting. mask_0 = (actions.tensor < self.ndim).squeeze(-1) @@ -138,9 +167,7 @@ def step( ) return states.tensor - def backward_step( - self, states: States, actions: Actions - ) -> TT["batch_shape", "state_shape", torch.float]: + def backward_step(self, states: States, actions: Actions) -> torch.Tensor: """Performs a backward step. In this env, states are n-dim vectors. s0 is empty (represented as -1), @@ -152,33 +179,61 @@ def backward_step( """ return states.tensor.scatter(-1, actions.tensor.fmod(self.ndim), -1) - def reward(self, final_states: DiscreteStates) -> TT["batch_shape"]: + def reward(self, final_states: DiscreteStates) -> torch.Tensor: """Not used during training but provided for completeness. Note the effect of clipping will be seen in these values. + + Args: + final_states: DiscreteStates object representing the final states. + + Returns the reward as tensor of shape (*batch_shape). """ - return torch.exp(self.log_reward(final_states)) + reward = torch.exp(self.log_reward(final_states)) + assert reward.shape == final_states.batch_shape + return reward - def log_reward(self, final_states: DiscreteStates) -> TT["batch_shape"]: - """The energy weighted by alpha is our log reward.""" + def log_reward(self, final_states: DiscreteStates) -> torch.Tensor: + """The energy weighted by alpha is our log reward. + + Args: + final_states: DiscreteStates object representing the final states. + + Returns the log reward as tensor of shape (*batch_shape).""" raw_states = final_states.tensor canonical = 2 * raw_states - 1 log_reward = -self.alpha * self.energy(canonical) + assert log_reward.shape == final_states.batch_shape return log_reward - def get_states_indices(self, states: DiscreteStates) -> TT["batch_shape"]: - """The chosen encoding is the following: -1 -> 0, 0 -> 1, 1 -> 2, then we convert to base 3""" + def get_states_indices(self, states: DiscreteStates) -> torch.Tensor: + """The chosen encoding is the following: -1 -> 0, 0 -> 1, 1 -> 2, then we convert to base 3 + + Args: + states: DiscreteStates object representing the states. + + Returns the states indices as tensor of shape (*batch_shape). + """ states_raw = states.tensor canonical_base = 3 ** torch.arange(self.ndim - 1, -1, -1, device=self.device) - return (states_raw + 1).mul(canonical_base).sum(-1).long() + states_indices = (states_raw + 1).mul(canonical_base).sum(-1).long() + assert states_indices.shape == states.batch_shape + return states_indices - def get_terminating_states_indices( - self, states: DiscreteStates - ) -> TT["batch_shape"]: + def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor: + """Returns the indices of the terminating states. + + Args: + states: DiscreteStates object representing the states. + + Returns the indices of the terminating states as tensor of shape (*batch_shape). + """ states_raw = states.tensor canonical_base = 2 ** torch.arange(self.ndim - 1, -1, -1, device=self.device) - return (states_raw).mul(canonical_base).sum(-1).long() + states_indices = (states_raw).mul(canonical_base).sum(-1).long() + assert states_indices.shape == states.batch_shape + return states_indices @property def n_states(self) -> int: diff --git a/src/gfn/gym/helpers/box_utils.py b/src/gfn/gym/helpers/box_utils.py index 14566be5..3d58fb2e 100644 --- a/src/gfn/gym/helpers/box_utils.py +++ b/src/gfn/gym/helpers/box_utils.py @@ -6,7 +6,6 @@ import torch import torch.nn as nn from torch.distributions import Beta, Categorical, Distribution, MixtureSameFamily -from torchtyping import TensorType as TT from gfn.gym import Box from gfn.modules import GFNModule @@ -35,21 +34,30 @@ def __init__( self, delta: float, northeastern: bool, - centers: TT["n_states", 2], - mixture_logits: TT["n_states", "n_components"], - alpha: TT["n_states", "n_components"], - beta: TT["n_states", "n_components"], + centers: States, # TODO: should probably be a tensor + mixture_logits: torch.Tensor, + alpha: torch.Tensor, + beta: torch.Tensor, ): + """Initializes the distribution. + + Args: + delta: the radius of the quarter disk. + northeastern: whether the quarter disk is northeastern or southwestern. + centers: the centers of the distribution with shape (n_states, 2). + mixture_logits: Tensor of shape (n_states", n_components) containing the logits of the mixture of Beta distributions. + alpha: Tensor of shape (n_states", n_components) containing the alpha parameters of the Beta distributions. + beta: Tensor of shape (n_states", n_components) containing the beta parameters of the Beta distributions. + """ self.delta = delta self.northeastern = northeastern + self.n_states, self.n_components = mixture_logits.shape + + assert centers.tensor.shape == (self.n_states, 2) self.centers = centers - self.n_states = centers.batch_shape[0] - self.n_components = mixture_logits.shape[1] - assert mixture_logits.shape == (self.n_states, self.n_components) assert alpha.shape == (self.n_states, self.n_components) assert beta.shape == (self.n_states, self.n_components) - self.base_dist = MixtureSameFamily( Categorical(logits=mixture_logits), Beta(alpha, beta), @@ -57,7 +65,11 @@ def __init__( self.min_angles, self.max_angles = self.get_min_and_max_angles() - def get_min_and_max_angles(self) -> Tuple[TT["n_states"], TT["n_states"]]: + def get_min_and_max_angles(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes the minimum and maximum angles for the distribution. + + Returns a tuple of two tensors of shape (n_states,) containing the minimum and maximum angles, respectively. + """ if self.northeastern: min_angles = torch.where( self.centers.tensor[..., 0] <= 1 - self.delta, @@ -81,9 +93,18 @@ def get_min_and_max_angles(self) -> Tuple[TT["n_states"], TT["n_states"]]: PI_2_INV * torch.arcsin((self.centers.tensor[..., 1]) / self.delta), ) + assert min_angles.shape == (self.n_states,) + assert max_angles.shape == (self.n_states,) return min_angles, max_angles - def sample(self, sample_shape: torch.Size = torch.Size()) -> TT["sample_shape", 2]: + def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: + """Samples from the distribution. + + Args: + sample_shape: the shape of the samples to generate. + + Returns the sampled actions of shape (sample_shape, n_states, 2). + """ base_01_samples = self.base_dist.sample(sample_shape=sample_shape) sampled_angles = ( @@ -141,9 +162,20 @@ def sample(self, sample_shape: torch.Size = torch.Size()) -> TT["sample_shape", ): raise ValueError("Sampled actions should be of norm delta ish") + assert sampled_actions.shape == sample_shape + (self.n_states, 2) return sampled_actions - def log_prob(self, sampled_actions: TT["batch_size", 2]) -> TT["batch_size"]: + def log_prob(self, sampled_actions: torch.Tensor) -> torch.Tensor: + """Computes the log probability of the sampled actions. + + Args: + sampled_actions: Tensor of shape (*batch_shape, 2) with the actions to compute the log probability of. + + Returns the log probability of the sampled actions as a tensor of shape `batch_shape`. + """ + assert sampled_actions.shape[-1] == 2 + batch_shape = sampled_actions.shape[:-1] + sampled_actions = sampled_actions.to( torch.double ) # Arccos is very brittle, so we use double precision @@ -204,6 +236,7 @@ def log_prob(self, sampled_actions: TT["batch_size", 2]) -> TT["batch_size"]: if torch.any(torch.isinf(logprobs)) or torch.any(torch.isnan(logprobs)): raise ValueError("logprobs contains inf or nan") + assert logprobs.shape == batch_shape return logprobs @@ -219,15 +252,25 @@ class QuarterDisk(Distribution): def __init__( self, delta: float, - mixture_logits: TT["n_components"], - alpha_r: TT["n_components"], - beta_r: TT["n_components"], - alpha_theta: TT["n_components"], - beta_theta: TT["n_components"], + mixture_logits: torch.Tensor, + alpha_r: torch.Tensor, + beta_r: torch.Tensor, + alpha_theta: torch.Tensor, + beta_theta: torch.Tensor, ): + """"Initializes the distribution. + + Args: + delta: the radius of the quarter disk. + mixture_logits: Tensor of shape (n_components,) containing the logits of the mixture of Beta distributions. + alpha_r: Tensor of shape (n_components,) containing the alpha parameters of the Beta distributions for the radius. + beta_r: Tensor of shape (n_components,) containing the beta parameters of the Beta distributions for the radius. + alpha_theta: Tensor of shape (n_components,) containing the alpha parameters of the Beta distributions for the angle. + beta_theta: Tensor of shape (n_components,) containing the beta parameters of the Beta distributions for the angle. + """ self.delta = delta self.mixture_logits = mixture_logits - self.n_components = mixture_logits.shape[0] + (self.n_components,) = mixture_logits.shape assert alpha_r.shape == (self.n_components,) assert beta_r.shape == (self.n_components,) @@ -244,7 +287,14 @@ def __init__( Beta(alpha_theta, beta_theta), ) - def sample(self, sample_shape: torch.Size = torch.Size()) -> TT["sample_shape", 2]: + def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: + """Samples from the distribution. + + Args: + sample_shape: the shape of the samples to generate. + + Returns the sampled actions of shape (sample_shape, 2). + """ base_r_01_samples = self.base_r_dist.sample(sample_shape=sample_shape) base_theta_01_samples = self.base_theta_dist.sample(sample_shape=sample_shape) @@ -258,9 +308,20 @@ def sample(self, sample_shape: torch.Size = torch.Size()) -> TT["sample_shape", ) ) + assert sampled_actions.shape == sample_shape + (2,) return sampled_actions - def log_prob(self, sampled_actions: TT["batch_size", 2]) -> TT["batch_size"]: + def log_prob(self, sampled_actions: torch.Tensor) -> torch.Tensor: + """Computes the log probability of the sampled actions. + + Args: + sampled_actions: Tensor of shape (*batch_shape, 2) with the actions to compute the log probability of. + + Returns the log probability of the sampled actions as a tensor of shape `batch_shape`. + """ + assert sampled_actions.shape[-1] == 2 + batch_shape = sampled_actions.shape[:-1] + sampled_actions = sampled_actions.to( torch.double ) # Arccos is very brittle, so we use double precision @@ -287,6 +348,7 @@ def log_prob(self, sampled_actions: TT["batch_size", 2]) -> TT["batch_size"]: if torch.any(torch.isinf(logprobs)): raise ValueError("logprobs contains inf") + assert logprobs.shape == batch_shape return logprobs @@ -299,13 +361,30 @@ class QuarterCircleWithExit(Distribution): def __init__( self, delta: float, - centers: TT["n_states", 2], - exit_probability: TT["n_states"], - mixture_logits: TT["n_states", "n_components"], - alpha: TT["n_states", "n_components"], - beta: TT["n_states", "n_components"], + centers: States, # TODO: should probably be a tensor + exit_probability: torch.Tensor, + mixture_logits: torch.Tensor, + alpha: torch.Tensor, + beta: torch.Tensor, epsilon: float = 1e-4, ): + """Initializes the distribution. + + Args: + delta: the radius of the quarter disk. + centers: the centers of the distribution with shape (n_states, 2). + exit_probability: Tensor of shape (n_states,) containing the probability of exiting the quarter disk. + mixture_logits: Tensor of shape (n_states, n_components) containing the logits of the mixture of Beta distributions. + alpha: Tensor of shape (n_states, n_components) containing the alpha parameters of the Beta distributions. + beta: Tensor of shape (n_states, n_components) containing the beta parameters of the Beta distributions. + epsilon: the epsilon value to consider the state as being at the border of the square. + """ + self.n_states, n_components = mixture_logits.shape + assert centers.tensor.shape == (self.n_states, 2) + assert exit_probability.shape == (self.n_states,) + assert alpha.shape == (self.n_states, n_components) + assert beta.shape == (self.n_states, n_components) + self.delta = delta self.epsilon = epsilon self.centers = centers @@ -322,7 +401,14 @@ def __init__( centers.device ) - def sample(self, sample_shape=()): + def sample(self, sample_shape=()) -> torch.Tensor: + """Samples from the distribution. + + Args: + sample_shape: the shape of the samples to generate. + + Returns the sampled actions of shape (sample_shape, n_states, 2). + """ actions = self.dist_without_exit.sample(sample_shape) repeated_exit_probability = self.exit_probability.repeat(sample_shape + (1,)) exit_mask = torch.bernoulli(repeated_exit_probability).bool() @@ -340,6 +426,7 @@ def sample(self, sample_shape=()): exit_mask[torch.any(self.centers.tensor >= 1 - self.epsilon, dim=-1)] = True actions[exit_mask] = self.exit_action + assert actions.shape == sample_shape + (self.n_states, 2) return actions def log_prob(self, sampled_actions): @@ -489,9 +576,17 @@ def __init__( # impossible at t=0). self.PFs0 = torch.nn.Parameter(torch.zeros(1, 5 * self.n_components_s0)) - def forward( - self, preprocessed_states: TT["batch_shape", 2, float] - ) -> TT["batch_shape", "1 + 5 * n_components"]: + def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: + """Computes the forward pass of the neural network. + + Args: + preprocessed_states: The tensor states of shape (*batch_shape, 2) to compute the forward pass of the neural network. + + Returns the output of the neural network as a tensor of shape (*batch_shape, 1 + 5 * max_n_components). + """ + assert preprocessed_states.shape[-1] == 2 + batch_shape = preprocessed_states.shape[:-1] + if preprocessed_states.ndim != 2: raise ValueError( f"preprocessed_states should be of shape [B, 2], got {preprocessed_states.shape}" @@ -550,6 +645,7 @@ def forward( desired_out[..., 1 + self._n_comp_max :] ) + assert desired_out.shape == batch_shape + (1 + 5 * self._n_comp_max,) return desired_out @@ -591,14 +687,23 @@ def __init__( self.n_components = n_components - def forward( - self, preprocessed_states: TT["batch_shape", 2, float] - ) -> TT["batch_shape", "3 * n_components"]: + def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: + """Computes the forward pass of the neural network. + + Args: + preprocessed_states: The tensor states of shape (*batch_shape, 2) to compute the forward pass of the neural network. + + Returns the output of the neural network as a tensor of shape (*batch_shape, 3 * n_components). + """ + assert preprocessed_states.shape[-1] == 2 + batch_shape = preprocessed_states.shape[:-1] + out = super().forward(preprocessed_states) # Apply sigmoid to all except the dimensions between 0 and self.n_components. out[..., self.n_components :] = torch.sigmoid(out[..., self.n_components :]) + assert out.shape == batch_shape + (3 * self.n_components,) return out @@ -609,9 +714,14 @@ def __init__(self, logZ_value: torch.Tensor, **kwargs: Any): super().__init__(**kwargs) self.logZ_value = nn.Parameter(logZ_value) - def forward( - self, preprocessed_states: TT["batch_shape", "input_dim", float] - ) -> TT["batch_shape", "output_dim", float]: + def forward( self, preprocessed_states: torch.Tensor) -> torch.Tensor: + """Computes the forward pass of the neural network. + + Args: + preprocessed_states: The tensor states of shape (*batch_shape, input_dim) to compute the forward pass of the neural network. + + Returns the output of the neural network as a tensor of shape (*batch_shape, output_dim). + """ out = super().forward(preprocessed_states) idx_s0 = torch.all(preprocessed_states == 0.0, 1) out[idx_s0] = self.logZ_value @@ -628,22 +738,25 @@ class BoxPBUniform(torch.nn.Module): input_dim = 2 - def forward( - self, preprocessed_states: TT["batch_shape", 2, float] - ) -> TT["batch_shape", 3]: + def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: + """Computes the forward pass of the neural network. + + Args: + preprocessed_states: The tensor states of shape (*batch_shape, 2) to compute the forward pass of the neural network. + + Returns a tensor of shape (*batch_shape, 3) filled by ones. + """ # return (1, 1, 1) for all states, thus the "+ (3,)". - return torch.ones( - preprocessed_states.shape[:-1] + (3,), device=preprocessed_states.device - ) + assert preprocessed_states.shape[-1] == 2 + batch_shape = preprocessed_states.shape[:-1] + return torch.ones(batch_shape + (3,), device=preprocessed_states.device) -def split_PF_module_output( - output: TT["batch_shape", "output_dim", float], n_comp_max: int -): +def split_PF_module_output(output: torch.Tensor, n_comp_max: int): """Splits the module output into the expected parameter sets. Args: - output: the module_output from the P_F model. + output: the module_output from the P_F model as a tensor of shape (*batch_shape, output_dim). n_comp_max: the larger number of the two n_components and n_components_s0. Returns: @@ -703,8 +816,15 @@ def expected_output_dim(self) -> int: return 1 + 5 * self._n_comp_max def to_probability_distribution( - self, states: States, module_output: TT["batch_shape", "output_dim", float] + self, states: States, module_output: torch.Tensor ) -> Distribution: + """Converts the module output to a probability distribution. + + Args: + states: the states for which to convert the module output to a probability distribution. + module_output: the output of the module for the states as a tensor of shape (*batch_shape, output_dim). + + Returns the probability distribution for the states.""" # First, we verify that the batch shape of states is 1 assert len(states.batch_shape) == 1 @@ -783,8 +903,16 @@ def expected_output_dim(self) -> int: return 3 * self.n_components def to_probability_distribution( - self, states: States, module_output: TT["batch_shape", "output_dim", float] + self, states: States, module_output: torch.Tensor ) -> Distribution: + """Converts the module output to a probability distribution. + + Args: + states: the states for which to convert the module output to a probability distribution. + module_output: the output of the module for the states as a tensor of shape (*batch_shape, output_dim). + + Returns the probability distribution for the states. + """ # First, we verify that the batch shape of states is 1 assert len(states.batch_shape) == 1 mixture_logits, alpha, beta = torch.split( diff --git a/src/gfn/gym/helpers/preprocessors.py b/src/gfn/gym/helpers/preprocessors.py index bb21e0ff..99721fe5 100644 --- a/src/gfn/gym/helpers/preprocessors.py +++ b/src/gfn/gym/helpers/preprocessors.py @@ -3,7 +3,6 @@ import torch from einops import rearrange from torch.nn.functional import one_hot -from torchtyping import TensorType as TT from gfn.preprocessors import Preprocessor from gfn.states import States @@ -13,19 +12,20 @@ class OneHotPreprocessor(Preprocessor): def __init__( self, n_states: int, - get_states_indices: Callable[[States], TT["batch_shape", "input_dim"]], + get_states_indices: Callable[[States], torch.Tensor], ) -> None: """One Hot Preprocessor for environments with enumerable states (finite number of states). Args: n_states (int): The total number of states in the environment (not including s_f). get_states_indices (Callable[[States], BatchOutputTensor]): function that returns the unique indices of the states. + BatchOutputTensor is a tensor of shape (*batch_shape, input_dim). """ super().__init__(output_dim=n_states) self.get_states_indices = get_states_indices self.output_dim = n_states - def preprocess(self, states): + def preprocess(self, states) -> torch.Tensor: state_indices = self.get_states_indices(states) return one_hot(state_indices, self.output_dim).float() @@ -35,7 +35,7 @@ def __init__( self, height: int, ndim: int, - get_states_indices: Callable[[States], TT["batch_shape", "input_dim"]], + get_states_indices: Callable[[States], torch.Tensor], ) -> None: """K Hot Preprocessor for environments with enumerable states (finite number of states) with a grid structure. @@ -43,6 +43,8 @@ def __init__( height (int): number of unique values per dimension. ndim (int): number of dimensions. get_states_indices (Callable[[States], BatchOutputTensor]): function that returns the unique indices of the states. + BatchOutputTensor is a tensor of shape (*batch_shape, input_dim). + """ super().__init__(output_dim=height * ndim) self.height = height diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index 9d6d7d0f..23d21137 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -6,7 +6,6 @@ import torch from einops import rearrange -from torchtyping import TensorType as TT from gfn.actions import Actions from gfn.env import DiscreteEnv @@ -87,7 +86,6 @@ def __init__( def update_masks(self, states: type[DiscreteStates]) -> None: """Update the masks based on the current states.""" - states.set_default_typing() # Not allowed to take any action beyond the environment height, but # allow early termination. states.set_nonexit_action_masks( @@ -98,29 +96,57 @@ def update_masks(self, states: type[DiscreteStates]) -> None: def make_random_states_tensor( self, batch_shape: Tuple[int, ...] - ) -> TT["batch_shape", "state_shape", torch.float]: - """Creates a batch of random states.""" + ) -> torch.Tensor: + """Creates a batch of random states. + + Args: + batch_shape: Tuple indicating the shape of the batch. + + Returns the batch of random states as tensor of shape (*batch_shape, *state_shape).""" return torch.randint( 0, self.height, batch_shape + self.s0.shape, device=self.device ) def step( self, states: DiscreteStates, actions: Actions - ) -> TT["batch_shape", "state_shape", torch.float]: + ) -> torch.Tensor: + """Take a step in the environment. + + Args: + states: The current states. + actions: The actions to take. + + Returns the new states after taking the actions as a tensor of shape (*batch_shape, *state_shape). + """ new_states_tensor = states.tensor.scatter(-1, actions.tensor, 1, reduce="add") + assert new_states_tensor.shape == states.tensor.shape return new_states_tensor def backward_step( self, states: DiscreteStates, actions: Actions - ) -> TT["batch_shape", "state_shape", torch.float]: + ) -> torch.Tensor: + """Take a step in the environment in the backward direction. + + Args: + states: The current states. + actions: The actions to take. + + Returns the new states after taking the actions as a tensor of shape (*batch_shape, *state_shape). + """ new_states_tensor = states.tensor.scatter(-1, actions.tensor, -1, reduce="add") + assert new_states_tensor.shape == states.tensor.shape return new_states_tensor - def reward(self, final_states: DiscreteStates) -> TT["batch_shape", torch.float]: + def reward(self, final_states: DiscreteStates) -> torch.Tensor: r"""In the normal setting, the reward is: R(s) = R_0 + 0.5 \prod_{d=1}^D \mathbf{1} \left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.25, 0.5] \right) + 2 \prod_{d=1}^D \mathbf{1} \left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.3, 0.4) \right) + + Args: + final_states: The final states. + + Returns the reward as a tensor of shape `batch_shape`. """ final_states_raw = final_states.tensor R0, R1, R2 = (self.R0, self.R1, self.R2) @@ -133,22 +159,36 @@ def reward(self, final_states: DiscreteStates) -> TT["batch_shape", torch.float] pdf_input = ax * 5 pdf = 1.0 / (2 * torch.pi) ** 0.5 * torch.exp(-(pdf_input**2) / 2) reward = R0 + ((torch.cos(ax * 50) + 1) * pdf).prod(-1) * R1 + + assert reward.shape == final_states.batch_shape return reward def get_states_indices( self, states: DiscreteStates - ) -> TT["batch_shape", torch.long]: + ) -> torch.Tensor: + """Get the indices of the states in the canonical ordering. + + Args: + states: The states to get the indices of. + + Returns the indices of the states in the canonical ordering as a tensor of shape `batch_shape`. + """ states_raw = states.tensor canonical_base = self.height ** torch.arange( self.ndim - 1, -1, -1, device=states_raw.device ) indices = (canonical_base * states_raw).sum(-1).long() + assert indices.shape == states.batch_shape return indices def get_terminating_states_indices( self, states: DiscreteStates - ) -> TT["batch_shape", torch.long]: + ) -> torch.Tensor: + """Get the indices of the terminating states in the canonical ordering. + + Returns the indices of the terminating states in the canonical ordering as a tensor of shape `batch_shape`. + """ return self.get_states_indices(states) @property diff --git a/src/gfn/gym/line.py b/src/gfn/gym/line.py index a0b2534e..90f0aeaf 100644 --- a/src/gfn/gym/line.py +++ b/src/gfn/gym/line.py @@ -2,7 +2,6 @@ import torch from torch.distributions import Normal # TODO: extend to Beta -from torchtyping import TensorType as TT from gfn.actions import Actions from gfn.env import Env @@ -46,20 +45,38 @@ def __init__( def step( self, states: States, actions: Actions - ) -> TT["batch_shape", 2, torch.float]: + ) -> torch.Tensor: + """Take a step in the environment. + + Args: + states: The current states. + actions: The actions to take. + + Returns the new states after taking the actions as a tensor of shape (*batch_shape, 2). + """ states.tensor[..., 0] = states.tensor[..., 0] + actions.tensor.squeeze( -1 ) # x position. states.tensor[..., 1] = states.tensor[..., 1] + 1 # Step counter. + assert states.tensor.shape == states.batch_shape + (2,) return states.tensor def backward_step( self, states: States, actions: Actions - ) -> TT["batch_shape", 2, torch.float]: + ) -> torch.Tensor: + """Take a step in the environment in the backward direction. + + Args: + states: The current states. + actions: The actions to take. + + Returns the new states after taking the actions as a tensor of shape (*batch_shape, 2). + """ states.tensor[..., 0] = states.tensor[..., 0] - actions.tensor.squeeze( -1 ) # x position. states.tensor[..., 1] = states.tensor[..., 1] - 1 # Step counter. + assert states.tensor.shape == states.batch_shape + (2,) return states.tensor def is_action_valid( @@ -71,13 +88,22 @@ def is_action_valid( return True - def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]: + def log_reward(self, final_states: States) -> torch.Tensor: + """Log reward log of the environment. + + Args: + final_states: The final states of the environment. + + Returns the log reward as a tensor of shape `batch_shape`. + """ s = final_states.tensor[..., 0] log_rewards = torch.empty((len(self.mixture),) + final_states.batch_shape) for i, m in enumerate(self.mixture): log_rewards[i] = m.log_prob(s) - return torch.logsumexp(log_rewards, 0) + log_rewards = torch.logsumexp(log_rewards, 0) + assert log_rewards.shape == final_states.batch_shape + return log_rewards @property def log_partition(self) -> float: diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 14515bab..a0ddfed2 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn from torch.distributions import Categorical, Distribution -from torchtyping import TensorType as TT from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.states import DiscreteStates, States @@ -73,9 +72,13 @@ def __init__( self._output_dim_is_checked = False self.is_backward = is_backward - def forward( - self, input: States | torch.Tensor - ) -> TT["batch_shape", "output_dim", float]: + def forward(self, input: States | torch.Tensor) -> torch.Tensor: + """Forward pass of the module. + + Args: + input: The input to the module, as states or a tensor. + + Returns the output of the module, as a tensor of shape (*batch_shape, output_dim).""" if isinstance(input, States): input = self.preprocessor(input) @@ -95,10 +98,9 @@ def __repr__(self): def expected_output_dim(self) -> int: """Expected output dimension of the module.""" - def check_output_dim( - self, module_output: TT["batch_shape", "output_dim", float] - ) -> None: + def check_output_dim(self, module_output: torch.Tensor) -> None: """Check that the output of the module has the correct shape. Raises an error if not.""" + assert module_output.dtype == torch.float if module_output.shape[-1] != self.expected_output_dim(): raise ValueError( f"{self.__class__.__name__} output dimension should be {self.expected_output_dim()}" @@ -108,7 +110,7 @@ def check_output_dim( def to_probability_distribution( self, states: States, - module_output: TT["batch_shape", "output_dim", float], + module_output: torch.Tensor, **policy_kwargs: Any, ) -> Distribution: """Transform the output of the module into a probability distribution. @@ -119,6 +121,13 @@ def to_probability_distribution( policy from a module's outputs. See `DiscretePolicyEstimator` for an example using a categorical distribution, but note this can be done for all continuous distributions as well. + + Args: + states: The states to use. + module_output: The output of the module as a tensor of shape (*batch_shape, output_dim). + **policy_kwargs: Keyword arguments to modify the distribution. + + Returns a distribution object. """ raise NotImplementedError @@ -169,7 +178,7 @@ def expected_output_dim(self) -> int: def to_probability_distribution( self, states: DiscreteStates, - module_output: TT["batch_shape", "output_dim", float], + module_output: torch.Tensor, temperature: float = 1.0, sf_bias: float = 0.0, epsilon: float = 0.0, @@ -179,6 +188,8 @@ def to_probability_distribution( We handle off-policyness using these kwargs. Args: + states: The states to use. + module_output: The output of the module as a tensor of shape (*batch_shape, output_dim). temperature: scalar to divide the logits by before softmax. Does nothing if set to 1.0 (default), in which case it's on policy. sf_bias: scalar to subtract from the exit action logit before dividing by @@ -186,6 +197,8 @@ def to_probability_distribution( on policy. epsilon: with probability epsilon, a random action is chosen. Does nothing if set to 0.0 (default), in which case it's on policy.""" + self.check_output_dim(module_output) + masks = states.backward_masks if self.is_backward else states.forward_masks logits = module_output logits[~masks] = -float("inf") @@ -242,7 +255,15 @@ def __init__( def _forward_trunk( self, states: States, conditioning: torch.Tensor - ) -> TT["batch_shape", "output_dim", float]: + ) -> torch.Tensor: + """Forward pass of the trunk of the module. + + Args: + states: The input states. + conditioning: The conditioning input. + + Returns the output of the trunk of the module, as a tensor of shape (*batch_shape, output_dim). + """ state_out = self.module(self.preprocessor(states)) conditioning_out = self.conditioning_module(conditioning) out = self.final_module(torch.cat((state_out, conditioning_out), -1)) @@ -251,7 +272,15 @@ def _forward_trunk( def forward( self, states: States, conditioning: torch.tensor - ) -> TT["batch_shape", "output_dim", float]: + ) -> torch.Tensor: + """Forward pass of the module. + + Args: + states: The input states. + conditioning: The conditioning input. + + Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). + """ out = self._forward_trunk(states, conditioning) if not self._output_dim_is_checked: @@ -281,7 +310,15 @@ def __init__( def forward( self, states: States, conditioning: torch.tensor - ) -> TT["batch_shape", "output_dim", float]: + ) -> torch.Tensor: + """Forward pass of the module. + + Args: + states: The input states. + conditioning: The tensor for conditioning. + + Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). + """ out = self._forward_trunk(states, conditioning) if not self._output_dim_is_checked: @@ -296,7 +333,16 @@ def expected_output_dim(self) -> int: def to_probability_distribution( self, states: States, - module_output: TT["batch_shape", "output_dim", float], + module_output: torch.Tensor, **policy_kwargs: Any, ) -> Distribution: + """Transform the output of the module into a probability distribution. + + Args: + states: The states to use. + module_output: The output of the module as a tensor of shape (*batch_shape, output_dim). + **policy_kwargs: Keyword arguments to modify the distribution. + + Returns a distribution object. + """ raise NotImplementedError diff --git a/src/gfn/preprocessors.py b/src/gfn/preprocessors.py index c980168d..6754fabd 100644 --- a/src/gfn/preprocessors.py +++ b/src/gfn/preprocessors.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod from typing import Callable -from torchtyping import TensorType as TT - +import torch from gfn.states import States @@ -16,11 +15,21 @@ def __init__(self, output_dim: int) -> None: self.output_dim = output_dim @abstractmethod - def preprocess(self, states: States) -> TT["batch_shape", "input_dim"]: + def preprocess(self, states: States) -> torch.Tensor: + """Transform the states to the input of the neural network. + + Args: + states: The states to preprocess. + + Returns the preprocessed states as a tensor of shape (*batch_shape, output_dim). + """ pass - def __call__(self, states: States) -> TT["batch_shape", "input_dim"]: - return self.preprocess(states) + def __call__(self, states: States) -> torch.Tensor: + """Transform the states to the input of the neural network, calling the preprocess method.""" + out = self.preprocess(states) + assert out.shape[-1] == self.output_dim + return out def __repr__(self): return f"{self.__class__.__name__}, output_dim={self.output_dim}" @@ -30,7 +39,8 @@ class IdentityPreprocessor(Preprocessor): """Simple preprocessor applicable to environments with uni-dimensional states. This is the default preprocessor used.""" - def preprocess(self, states: States) -> TT["batch_shape", "input_dim"]: + def preprocess(self, states: States) -> torch.Tensor: + """Identity preprocessor. Returns the states as they are.""" return ( states.tensor.float() ) # TODO: should we typecast here? not a true identity... @@ -41,16 +51,24 @@ class EnumPreprocessor(Preprocessor): def __init__( self, - get_states_indices: Callable[[States], TT["batch_shape", "input_dim"]], + get_states_indices: Callable[[States], torch.Tensor], ) -> None: """Preprocessor for environments with enumerable states (finite number of states). Each state is represented by a unique integer (>= 0) index. Args: get_states_indices (Callable[[States], BatchOutputTensor]): function that returns the unique indices of the states. + BatchOutputTensor is a tensor of shape (*batch_shape, 1). """ super().__init__(output_dim=1) self.get_states_indices = get_states_indices - def preprocess(self, states): + def preprocess(self, states) -> torch.Tensor: + """Preprocess the states by returning their unique indices. + + Args: + states: The states to preprocess. + + Returns the unique indices of the states as a tensor of shape `batch_shape`. + """ return self.get_states_indices(states).long().unsqueeze(-1) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 2712c1f5..a48ac5bb 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -2,7 +2,6 @@ from typing import List, Optional, Tuple, Any import torch -from torchtyping import TensorType as TT from gfn.actions import Actions from gfn.containers import Trajectories @@ -38,8 +37,8 @@ def sample_actions( **policy_kwargs: Any, ) -> Tuple[ Actions, - TT["batch_shape", torch.float] | None, - TT["batch_shape", torch.float] | None, + torch.Tensor | None, + torch.Tensor | None, ]: """Samples actions from the given states. @@ -67,9 +66,10 @@ def sample_actions( Returns: A tuple of tensors containing: - An Actions object containing the sampled actions. - - A tensor of shape (*batch_shape,) containing the log probabilities of + - An optional tensor of shape `batch_shape` containing the log probabilities of the sampled actions under the probability distribution of the given states. + - An optional tensor of shape `batch_shape` containing the estimator outputs """ # TODO: Should estimators instead ignore None for the conditioning vector? if conditioning is not None: @@ -94,10 +94,11 @@ def sample_actions( log_probs = None actions = env.actions_from_tensor(actions) - if not save_estimator_outputs: estimator_output = None + assert log_probs is None or log_probs.shape == actions.batch_shape + # assert estimator_output is None or estimator_output.shape == actions.batch_shape TODO: check expected shape return actions, log_probs, estimator_output def sample_trajectories( @@ -159,8 +160,8 @@ def sample_trajectories( ) trajectories_states: List[States] = [deepcopy(states)] - trajectories_actions: List[TT["n_trajectories", torch.long]] = [] - trajectories_logprobs: List[TT["n_trajectories", torch.float]] = [] + trajectories_actions: List[torch.Tensor] = [] + trajectories_logprobs: List[torch.Tensor] = [] trajectories_dones = torch.zeros( n_trajectories, dtype=torch.long, device=device ) diff --git a/src/gfn/states.py b/src/gfn/states.py index fac0ac09..9442e24e 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,11 +3,9 @@ from abc import ABC from copy import deepcopy from math import prod -from typing import Callable, ClassVar, List, Optional, Sequence, cast +from typing import Callable, ClassVar, List, Optional, Sequence import torch -from torch import Tensor -from torchtyping import TensorType as TT class States(ABC): @@ -47,21 +45,23 @@ class States(ABC): """ state_shape: ClassVar[tuple[int, ...]] # Shape of one state - s0: ClassVar[TT["state_shape", torch.float]] # Source state of the DAG - sf: ClassVar[ - TT["state_shape", torch.float] - ] # Dummy state, used to pad a batch of states + s0: ClassVar[torch.Tensor] # Source state of the DAG + sf: ClassVar[torch.Tensor] # Dummy state, used to pad a batch of states make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw( NotImplementedError( "The environment does not support initialization of random states." ) ) - def __init__(self, tensor: TT["batch_shape", "state_shape"]): + def __init__(self, tensor: torch.Tensor): """Initalize the State container with a batch of states. Args: - tensor: Tensor representing a batch of states. + tensor: Tensor of shape (*batch_shape, *state_shape) representing a batch of states. """ + assert self.s0 .shape == self.state_shape + assert self.sf.shape == self.state_shape + assert tensor.shape[-len(self.state_shape) :] == self.state_shape + self.tensor = tensor self.batch_shape = tuple(self.tensor.shape)[: -len(self.state_shape)] self._log_rewards = ( @@ -102,8 +102,14 @@ def from_batch_shape( @classmethod def make_initial_states_tensor( cls, batch_shape: tuple[int] - ) -> TT["batch_shape", "state_shape", torch.float]: - """Makes a tensor with a `batch_shape` of states consisting of $s_0`$s.""" + ) -> torch.Tensor: + """Makes a tensor with a `batch_shape` of states consisting of $s_0`$s. + + Args: + batch_shape: Shape of the batch dimensions. + + Returns a tensor of shape (*batch_shape, *state_shape) with all states equal to $s_0$. + """ state_ndim = len(cls.state_shape) assert cls.s0 is not None and state_ndim is not None return cls.s0.repeat(*batch_shape, *((1,) * state_ndim)) @@ -111,8 +117,13 @@ def make_initial_states_tensor( @classmethod def make_sink_states_tensor( cls, batch_shape: tuple[int] - ) -> TT["batch_shape", "state_shape", torch.float]: - """Makes a tensor with a `batch_shape` of states consisting of $s_f$s.""" + ) -> torch.Tensor: + """Makes a tensor with a `batch_shape` of states consisting of $s_f$s. + + Args: + batch_shape: Shape of the batch dimensions. + + Returns a tensor of shape (*batch_shape, *state_shape) with all states equal to $s_f$.""" state_ndim = len(cls.state_shape) assert cls.sf is not None and state_ndim is not None return cls.sf.repeat(*batch_shape, *((1,) * state_ndim)) @@ -128,7 +139,7 @@ def device(self) -> torch.device: return self.tensor.device def __getitem__( - self, index: int | Sequence[int] | Sequence[bool] | Tensor + self, index: int | Sequence[int] | Sequence[bool] | torch.Tensor ) -> States: """Access particular states of the batch.""" out = self.__class__( @@ -227,34 +238,35 @@ def extend_with_sf(self, required_first_dim: int) -> None: f"extend_with_sf is not implemented for batch shapes {self.batch_shape}" ) - def compare( - self, other: TT["batch_shape", "state_shape", torch.float] - ) -> TT["batch_shape", torch.bool]: + def compare(self, other: torch.tensor) -> torch.Tensor: """Computes elementwise equality between state tensor with an external tensor. Args: - other: Tensor of states to compare to. + other: Tensor with shape (*batch_shape, *state_shape) representing states to compare to. - Returns: Tensor of booleans indicating whether the states are equal to the - states in self. + Returns a tensor of booleans with shape `batch_shape` indicating whether the states are equal + to the states in self. """ + assert other.shape == self.batch_shape + self.state_shape out = self.tensor == other state_ndim = len(self.__class__.state_shape) for _ in range(state_ndim): out = out.all(dim=-1) + + assert out.shape == self.batch_shape return out @property - def is_initial_state(self) -> TT["batch_shape", torch.bool]: - """Return a tensor that is True for states that are $s_0$ of the DAG.""" + def is_initial_state(self) -> torch.Tensor: + """Returns a tensor of shape `batch_shape` that is True for states that are $s_0$ of the DAG.""" source_states_tensor = self.__class__.s0.repeat( *self.batch_shape, *((1,) * len(self.__class__.state_shape)) ) return self.compare(source_states_tensor) @property - def is_sink_state(self) -> TT["batch_shape", torch.bool]: - """Return a tensor that is True for states that are $s_f$ of the DAG.""" + def is_sink_state(self) -> torch.Tensor: + """Returns a tensor of shape `batch_shape` that is True for states that are $s_f$ of the DAG.""" # TODO: self.__class__.sf == self.tensor -- or something similar? sink_states = self.__class__.sf.repeat( *self.batch_shape, *((1,) * len(self.__class__.state_shape)) @@ -262,11 +274,18 @@ def is_sink_state(self) -> TT["batch_shape", torch.bool]: return self.compare(sink_states) @property - def log_rewards(self) -> TT["batch_shape", torch.float]: + def log_rewards(self) -> torch.Tensor: + """Returns the log rewards of the states as tensor of shape `batch_shape`.""" return self._log_rewards @log_rewards.setter - def log_rewards(self, log_rewards: TT["batch_shape", torch.float]) -> None: + def log_rewards(self, log_rewards: torch.Tensor) -> None: + """Sets the log rewards of the states. + + Args: + log_rewards: Tensor of shape `batch_shape` representing the log rewards of the states. + """ + assert log_rewards.shape == self.batch_shape self._log_rewards = log_rewards def sample(self, n_samples: int) -> States: @@ -292,39 +311,40 @@ class DiscreteStates(States, ABC): def __init__( self, - tensor: TT["batch_shape", "state_shape", torch.float], - forward_masks: Optional[TT["batch_shape", "n_actions", torch.bool]] = None, - backward_masks: Optional[TT["batch_shape", "n_actions - 1", torch.bool]] = None, + tensor: torch.Tensor, + forward_masks: Optional[torch.Tensor] = None, + backward_masks: Optional[torch.Tensor] = None, ) -> None: """Initalize a DiscreteStates container with a batch of states and masks. Args: - tensor: A batch of states. - forward_masks: Initializes a boolean tensor of allowable forward policy - actions. - backward_masks: Initializes a boolean tensor of allowable backward policy - actions. + tensor: A tensor with shape (*batch_shape, *state_shape) representing a batch of states. + forward_masks: Optional boolean tensor tensor with shape (*batch_shape, n_actions) of + allowable forward policy actions. + backward_masks: Optional boolean tensor tensor with shape (*batch_shape, n_actions) of + allowable backward policy actions. """ super().__init__(tensor) + assert tensor.shape == self.batch_shape + self.state_shape # In the usual case, no masks are provided and we produce these defaults. # Note: this **must** be updated externally by the env. - if isinstance(forward_masks, type(None)): + if forward_masks is None: forward_masks = torch.ones( (*self.batch_shape, self.__class__.n_actions), dtype=torch.bool, device=self.__class__.device, ) - if isinstance(backward_masks, type(None)): + if backward_masks is None: backward_masks = torch.ones( (*self.batch_shape, self.__class__.n_actions - 1), dtype=torch.bool, device=self.__class__.device, ) + assert forward_masks.shape == (*self.batch_shape, self.n_actions) + assert backward_masks.shape == (*self.batch_shape, self.n_actions - 1) - # Ensures typecasting is consistent no matter what is submitted to init. - self.forward_masks = cast(torch.Tensor, forward_masks) # TODO: Required? - self.backward_masks = cast(torch.Tensor, backward_masks) # TODO: Required? - self.set_default_typing() + self.forward_masks = forward_masks + self.backward_masks = backward_masks def clone(self) -> States: """Returns a clone of the current instance.""" @@ -334,17 +354,6 @@ def clone(self) -> States: self.backward_masks, ) - def set_default_typing(self) -> None: - """A convienience function for default typing of the masks.""" - self.forward_masks = cast( - TT["batch_shape", "n_actions", torch.bool], - self.forward_masks, - ) - self.backward_masks = cast( - TT["batch_shape", "n_actions - 1", torch.bool], - self.backward_masks, - ) - def _check_both_forward_backward_masks_exist(self): assert self.forward_masks is not None and self.backward_masks is not None diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index 78cb84cb..ecf541b7 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -1,6 +1,5 @@ import torch from torch.distributions import Categorical -from torchtyping import TensorType as TT class UnsqueezedCategorical(Categorical): @@ -18,10 +17,25 @@ class UnsqueezedCategorical(Categorical): should be of shape (batch_shape, 1). """ - def sample(self, sample_shape=torch.Size()) -> TT["sample_shape", 1]: - """Sample actions with an unsqueezed final dimension.""" - return super().sample(sample_shape).unsqueeze(-1) + def sample(self, sample_shape=torch.Size()) -> torch.Tensor: + """Sample actions with an unsqueezed final dimension. + + Args: + sample_shape: The shape of the sample. + + Returns the sampled actions as a tensor of shape (*sample_shape, *batch_shape, 1). + """ + out = super().sample(sample_shape).unsqueeze(-1) + assert out.shape == sample_shape + self._batch_shape + (1,) + return out - def log_prob(self, sample: TT["sample_shape", 1]) -> TT["sample_shape"]: - """Returns the log probabilities of an unsqueezed sample.""" + def log_prob(self, sample: torch.Tensor) -> torch.Tensor: + """Returns the log probabilities of an unsqueezed sample. + + Args: + sample: The sample of for which to compute the log probabilities. + + Returns the log probabilities of the sample as a tensor of shape (*sample_shape, *batch_shape). + """ + assert sample.shape[-1] == 1 return super().log_prob(sample.squeeze(-1)) diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index 22790e6e..ccabfd28 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn from torch.nn.parameter import Parameter -from torchtyping import TensorType as TT class NeuralNet(nn.Module): @@ -57,14 +56,14 @@ def __init__( self.last_layer = nn.Linear(self.trunk.hidden_dim, output_dim) def forward( - self, preprocessed_states: TT["batch_shape", "input_dim", float] - ) -> TT["batch_shape", "output_dim", float]: + self, preprocessed_states: torch.Tensor + ) -> torch.Tensor: """Forward method for the neural network. Args: preprocessed_states: a batch of states appropriately preprocessed for - ingestion by the MLP. - Returns: out, a set of continuous variables. + ingestion by the MLP. The shape of the tensor should be (*batch_shape, input_dim). + Returns: a tensor of shape (*batch_shape, output_dim). """ out = self.trunk(preprocessed_states) out = self.last_layer(out) @@ -107,8 +106,15 @@ def __init__(self, n_states: int, output_dim: int) -> None: self.device = None def forward( - self, preprocessed_states: TT["batch_shape", "input_dim", float] - ) -> TT["batch_shape", "output_dim", float]: + self, preprocessed_states: torch.Tensor + ) -> torch.Tensor: + """Forward method for the tabular policy. + + Args: + preprocessed_states: a batch of states appropriately preprocessed for + ingestion by the tabular policy. The shape of the tensor should be (*batch_shape, 1). + Returns: a tensor of shape (*batch_shape, output_dim). + """ if self.device is None: self.device = preprocessed_states.device self.table = self.table.to(self.device) @@ -138,8 +144,16 @@ def __init__(self, output_dim: int) -> None: self.output_dim = output_dim def forward( - self, preprocessed_states: TT["batch_shape", "input_dim", float] - ) -> TT["batch_shape", "output_dim", float]: + self, preprocessed_states: torch.Tensor + ) -> torch.Tensor: + """Forward method for the uniform distribution. + + Args: + preprocessed_states: a batch of states appropriately preprocessed for + ingestion by the uniform distribution. The shape of the tensor should be (*batch_shape, input_dim). + + Returns: a tensor of shape (*batch_shape, output_dim). + """ out = torch.zeros(*preprocessed_states.shape[:-1], self.output_dim).to( preprocessed_states.device ) diff --git a/src/gfn/utils/training.py b/src/gfn/utils/training.py index 9144154b..40c61f11 100644 --- a/src/gfn/utils/training.py +++ b/src/gfn/utils/training.py @@ -2,14 +2,21 @@ from typing import Dict, Optional import torch -from torchtyping import TensorType as TT from gfn.env import Env from gfn.gflownet import GFlowNet, TBGFlowNet from gfn.states import States -def get_terminating_state_dist_pmf(env: Env, states: States) -> TT["n_states", float]: +def get_terminating_state_dist_pmf(env: Env, states: States) -> torch.Tensor: + """Computes the empirical distribution of the terminating states. + + Args: + env: The environment. + states: The states to compute the distribution of. + + Returns the empirical distribution of the terminating states as a tensor of shape (n_terminating_states,). + """ states_indices = env.get_terminating_states_indices(states).cpu().numpy().tolist() counter = Counter(states_indices) counter_list = [ diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py index c43115f9..fe9a2863 100644 --- a/tutorials/examples/train_line.py +++ b/tutorials/examples/train_line.py @@ -3,7 +3,6 @@ import torch from torch.distributions import Distribution, Normal # TODO: extend to Beta from torch.distributions.independent import Independent -from torchtyping import TensorType as TT from tqdm import trange from gfn.gflownet import TBGFlowNet # TODO: Extend to SubTBGFlowNet @@ -73,9 +72,9 @@ class ScaledGaussianWithOptionalExit(Distribution): def __init__( self, - states: TT["n_states", 2], # Tensor of [x position, step counter]. - mus: TT["n_states", 1], # Parameter of Gaussian distribution. - scales: TT["n_states", 1], # Parameter of Gaussian distribution. + states: torch.Tensor, # Tensor with shape (n_states, 2) with [x position, step counter] for each state. + mus: torch.Tensor, # Tensor with shape (n_states, 1) with mean of Gaussian distribution for each state. + scales: torch.Tensor, # Tensor with shape (n_states, 1) with scale of Gaussian distribution for each state. backward: bool, n_steps: int = 5, ): @@ -146,17 +145,23 @@ def __init__( activation_fn="elu", ) - def forward( - self, preprocessed_states: TT["batch_shape", 2, float] - ) -> TT["batch_shape", "3"]: - """Calculate the gaussian parameters, applying the bound to sigma.""" - assert preprocessed_states.ndim == 2 + def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: + """Calculate the gaussian parameters, applying the bound to sigma. + + Args: + preprocessed_states: a tensor of shape (*batch_shape, 2) containing the states. + + Returns a tensor of shape (*batch_shape, 2) containing the mean and variance of the Gaussian distribution.""" + batch_shape, state_dim = preprocessed_states.shape + assert state_dim == 2 + out = super().forward(preprocessed_states) # [..., 2]: represents mean & std. minmax_norm = self.policy_std_max - self.policy_std_min out[..., 1] = ( torch.sigmoid(out[..., 1]) * minmax_norm + self.policy_std_min ) # Scales / Variances. + assert out.shape == (batch_shape, 2) return out @@ -174,9 +179,18 @@ def expected_output_dim(self) -> int: def to_probability_distribution( self, states: States, - module_output: TT["batch_shape", "output_dim", float], + module_output: torch.Tensor, scale_factor=0, # policy_kwarg. ) -> Distribution: + """Converts the output of the neural network to a probability distribution. + + Args: + states: The states to use for the distribution. + module_output: The output of the neural network as a tensor of shape (*batch_shape, output_dim). + scale_factor: The scale factor to use for the distribution. + + Returns a distribution object. + """ assert len(states.batch_shape) == 1 assert module_output.shape == states.batch_shape + (2,) # [locs, scales]. locs, scales = torch.split(module_output, [1, 1], dim=-1)