Merge branch 'master' into hyeok9855/minor-refactorings
hyeok9855 committed Oct 29, 2024
2 parents 8512dce + b3bae95 commit 5a4198e
Showing 24 changed files with 943 additions and 390 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -23,7 +23,6 @@ einops = ">=0.6.1"
numpy = ">=1.21.2"
python = "^3.10"
torch = ">=1.9.0"
torchtyping = ">=0.1.4"

# dev dependencies.
black = { version = "24.3", optional = true }
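The dependency drop reflects the broader change in this commit: shape information moves out of torchtyping's TensorType annotations and into plain torch.Tensor hints backed by docstrings and runtime asserts. A minimal sketch of that pattern (a hypothetical helper written for illustration, not code from this repository):

import torch

def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Averages x over the entries where mask is True.

    Args:
        x: tensor of shape (batch, n).
        mask: boolean tensor of shape (batch, n).

    Returns: tensor of shape (batch,).
    """
    assert mask.shape == x.shape and mask.dtype == torch.bool
    return (x * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)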
39 changes: 20 additions & 19 deletions src/gfn/actions.py
@@ -5,7 +5,6 @@
from typing import ClassVar, Sequence

import torch
from torchtyping import TensorType as TT


class Actions(ABC):
@@ -23,23 +22,22 @@ class Actions(ABC):
# The following class variable represents the shape of a single action.
action_shape: ClassVar[tuple[int, ...]] # All actions need to have the same shape.
# The following class variable is used to pad shorter trajectories.
dummy_action: ClassVar[TT["action_shape"]] # Dummy action for the environment.
dummy_action: ClassVar[torch.Tensor] # Dummy action for the environment.
# The following class variable corresponds to $s \rightarrow s_f$ transitions.
exit_action: ClassVar[TT["action_shape"]] # Action to exit the environment.
exit_action: ClassVar[torch.Tensor] # Action to exit the environment.

def __init__(self, tensor: TT["batch_shape", "action_shape"]):
def __init__(self, tensor: torch.Tensor):
"""Initialize actions from a tensor.
Args:
tensor: tensor of actions
tensor: a tensor representing a batch of actions, with shape (*batch_shape, *action_shape).
"""
self.tensor = tensor
assert len(tensor.shape) >= len(self.action_shape), (
f"Actions tensor has shape {tensor.shape}, "
f"but the action shape is {self.action_shape}."
# Ensure the tensor has all action dimensions.
assert tensor.shape[-len(self.action_shape):] == self.action_shape, (
f"Batched actions tensor has shape {tensor.shape}, but the expected action shape is {self.action_shape}."
)
self.batch_shape = tuple(self.tensor.shape)[: -len(self.action_shape)]

self.tensor = tensor
self.batch_shape = tuple(self.tensor.shape)[:-len(self.action_shape)]

@classmethod
def make_dummy_actions(cls, batch_shape: tuple[int]) -> Actions:
@@ -134,35 +132,38 @@ def extend_with_dummy_actions(self, required_first_dim: int) -> None:
"extend_with_dummy_actions is only implemented for bi-dimensional actions."
)

def compare(
self, other: TT["batch_shape", "action_shape"]
) -> TT["batch_shape", torch.bool]:
def compare(self, other: torch.Tensor) -> torch.Tensor:
"""Compares the actions to a tensor of actions.
Args:
other: tensor of actions
other: tensor of actions to compare, with shape (*batch_shape, *action_shape).
Returns: boolean tensor of shape batch_shape indicating whether the actions are
equal.
"""
assert other.shape == self.batch_shape + self.action_shape, (
f"Expected shape {self.batch_shape + self.action_shape}, got {other.shape}."
)
out = self.tensor == other
n_batch_dims = len(self.batch_shape)

# Flatten all action dims, then reduce with all() over them.
out = out.flatten(start_dim=n_batch_dims).all(dim=-1)

assert out.dtype == torch.bool and out.shape == self.batch_shape
return out

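The flatten-then-all reduction above can be checked in isolation; a small sketch with free-standing tensors (shapes and values chosen arbitrarily for illustration):

import torch

batch_shape, action_shape = (2, 3), (2,)
a = torch.randint(0, 5, batch_shape + action_shape)
b = a.clone()
b[0, 0, 0] += 1  # perturb a single action component

out = (a == b).flatten(start_dim=len(batch_shape)).all(dim=-1)
assert out.shape == batch_shape              # one boolean per batch entry
assert not out[0, 0] and bool(out[1, 2])     # only the perturbed entry differs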
@property
def is_dummy(self) -> TT["batch_shape", torch.bool]:
"""Returns a boolean tensor indicating whether the actions are dummy actions."""
def is_dummy(self) -> torch.Tensor:
"""Returns a boolean tensor of shape `batch_shape` indicating whether the actions are dummy actions."""
dummy_actions_tensor = self.__class__.dummy_action.repeat(
*self.batch_shape, *((1,) * len(self.__class__.action_shape))
)
return self.compare(dummy_actions_tensor)

@property
def is_exit(self) -> TT["batch_shape", torch.bool]:
"""Returns a boolean tensor indicating whether the actions are exit actions."""
def is_exit(self) -> torch.Tensor:
"""Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions."""
exit_actions_tensor = self.__class__.exit_action.repeat(
*self.batch_shape, *((1,) * len(self.__class__.action_shape))
)
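Taken together, the changes above replace shape-typed class variables with plain tensors whose trailing dimensions are now checked at construction time. A rough usage sketch (the DiscreteActions subclass and its action values are made up for illustration; only the Actions base class shown above is assumed):

import torch
from gfn.actions import Actions

class DiscreteActions(Actions):
    action_shape = (1,)                # each action is a single integer index
    dummy_action = torch.tensor([-1])  # used to pad shorter trajectories
    exit_action = torch.tensor([3])    # the s -> s_f action

actions = DiscreteActions(torch.tensor([[0], [3], [2]]))  # trailing dim matches action_shape
print(actions.batch_shape)  # (3,)
print(actions.is_exit)      # tensor([False,  True, False])

# The stricter assert now rejects tensors whose trailing dims do not match
# action_shape, e.g. DiscreteActions(torch.tensor([0, 3, 2])) raises AssertionError.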
78 changes: 49 additions & 29 deletions src/gfn/containers/trajectories.py
@@ -10,8 +10,6 @@

import numpy as np
import torch
from torch import Tensor
from torchtyping import TensorType as TT

from gfn.containers.base import Container
from gfn.containers.transitions import Transitions
@@ -20,7 +18,7 @@

def is_tensor(t) -> bool:
"""Checks whether t is a torch.Tensor instance."""
return isinstance(t, Tensor)
return isinstance(t, torch.Tensor)


# TODO: remove env from this class?
@@ -40,10 +38,10 @@ class Trajectories(Container):
env: The environment in which the trajectories are defined.
states: The states of the trajectories.
actions: The actions of the trajectories.
when_is_done: The time step at which each trajectory ends.
when_is_done: Tensor of shape (n_trajectories,) indicating the time step at which each trajectory ends.
is_backward: Whether the trajectories are backward or forward.
log_rewards: The log_rewards of the trajectories.
log_probs: The log probabilities of the trajectories' actions.
log_rewards: Tensor of shape (n_trajectories,) containing the log rewards of the trajectories.
log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of the trajectories' actions.
"""

@@ -53,23 +51,24 @@ def __init__(
states: States | None = None,
conditioning: torch.Tensor | None = None,
actions: Actions | None = None,
when_is_done: TT["n_trajectories", torch.long] | None = None,
when_is_done: torch.Tensor | None = None,
is_backward: bool = False,
log_rewards: TT["n_trajectories", torch.float] | None = None,
log_probs: TT["max_length", "n_trajectories", torch.float] | None = None,
estimator_outputs: TT["batch_shape", "output_dim", torch.float] | None = None,
log_rewards: torch.Tensor | None = None,
log_probs: torch.Tensor | None = None,
estimator_outputs: torch.Tensor | None = None,
) -> None:
"""
Args:
env: The environment in which the trajectories are defined.
states: The states of the trajectories.
actions: The actions of the trajectories.
when_is_done: The time step at which each trajectory ends.
when_is_done: Tensor of shape (n_trajectories,) indicating the time step at which each trajectory ends.
is_backward: Whether the trajectories are backward or forward.
log_rewards: The log_rewards of the trajectories.
log_probs: The log probabilities of the trajectories' actions.
estimator_outputs: When forward sampling off-policy for an n-step
trajectory, n forward passes will be made on some function approximator,
log_rewards: Tensor of shape (n_trajectories,) containing the log rewards of the trajectories.
log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of the trajectories' actions.
estimator_outputs: Tensor of shape (batch_shape, output_dim).
When forward sampling off-policy for an n-step trajectory,
n forward passes will be made on some function approximator,
which may need to be re-used (for example, for evaluating PF). To avoid
duplicated effort, the outputs of the forward passes can be stored here.
Expand All @@ -93,17 +92,25 @@ def __init__(
if when_is_done is not None
else torch.full(size=(0,), fill_value=-1, dtype=torch.long)
)
assert self.when_is_done.shape == (self.n_trajectories,) and self.when_is_done.dtype == torch.long

self._log_rewards = (
log_rewards
if log_rewards is not None
else torch.full(size=(0,), fill_value=0, dtype=torch.float)
)
self.log_probs = (
log_probs
if log_probs is not None
else torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
)
assert self._log_rewards.shape == (self.n_trajectories,) and self._log_rewards.dtype == torch.float

if log_probs is not None:
assert log_probs.shape == (self.max_length, self.n_trajectories) and log_probs.dtype == torch.float
else:
log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
self.log_probs = log_probs

self.estimator_outputs = estimator_outputs
if self.estimator_outputs is not None:
# assert self.estimator_outputs.shape[:len(self.states.batch_shape)] == self.states.batch_shape TODO: check why fails
assert self.estimator_outputs.dtype == torch.float

def __repr__(self) -> str:
states = self.states.tensor.transpose(0, 1)
@@ -142,7 +149,8 @@ def last_states(self) -> States:
return self.states[self.when_is_done - 1, torch.arange(self.n_trajectories)]

@property
def log_rewards(self) -> TT["n_trajectories", torch.float] | None:
def log_rewards(self) -> torch.Tensor | None:
"""Returns the log rewards of the trajectories as a tensor of shape (n_trajectories,)."""
if self._log_rewards is not None:
assert self._log_rewards.shape == (self.n_trajectories,)
return self._log_rewards
@@ -200,13 +208,23 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories:

@staticmethod
def extend_log_probs(
log_probs: TT["max_length", "n_trajectories", torch.float], new_max_length: int
) -> TT["max_max_length", "n_trajectories", torch.float]:
"""Extend the log_probs matrix by adding 0 until the required length is reached."""
if log_probs.shape[0] >= new_max_length:
log_probs: torch.Tensor, new_max_length: int
) -> torch.Tensor:
"""Extend the log_probs matrix by adding 0 until the required length is reached.
Args:
log_probs: The log_probs tensor of shape (max_length, n_trajectories) to extend.
new_max_length: The new length of the log_probs tensor.
Returns: The extended log_probs tensor of shape (new_max_length, n_trajectories).
"""

max_length, n_trajectories = log_probs.shape
if max_length >= new_max_length:
return log_probs
else:
return torch.cat(
new_log_probs = torch.cat(
(
log_probs,
torch.full(
@@ -221,6 +239,8 @@ def extend_log_probs(
),
dim=0,
)
assert new_log_probs.shape == (new_max_length, n_trajectories)
return new_log_probs

def extend(self, other: Trajectories) -> None:
"""Extend the trajectories with another set of trajectories.
@@ -267,11 +287,11 @@ def extend(self, other: Trajectories) -> None:
# Either set, or append, estimator outputs if they exist in the submitted
# trajectory.
if self.estimator_outputs is None and isinstance(
other.estimator_outputs, Tensor
other.estimator_outputs, torch.Tensor
):
self.estimator_outputs = other.estimator_outputs
elif isinstance(self.estimator_outputs, Tensor) and isinstance(
other.estimator_outputs, Tensor
elif isinstance(self.estimator_outputs, torch.Tensor) and isinstance(
other.estimator_outputs, torch.Tensor
):
batch_shape = self.actions.batch_shape
n_bs = len(batch_shape)
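The reworked extend_log_probs keeps its padding behaviour but now asserts the output shape explicitly. A quick sketch of what it does to a free-standing tensor (values are arbitrary; only the staticmethod shown above is assumed):

import torch
from gfn.containers.trajectories import Trajectories

log_probs = torch.zeros(3, 5)  # (max_length=3, n_trajectories=5)
longer = Trajectories.extend_log_probs(log_probs, new_max_length=7)
assert longer.shape == (7, 5)                 # padded along the time axis
assert torch.equal(longer[:3], log_probs)     # original entries are preserved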
30 changes: 21 additions & 9 deletions src/gfn/containers/transitions.py
@@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, Sequence

import torch
from torchtyping import TensorType as TT

if TYPE_CHECKING:
from gfn.actions import Actions
@@ -36,11 +35,11 @@ def __init__(
states: States | None = None,
conditioning: torch.Tensor | None = None,
actions: Actions | None = None,
is_done: TT["n_transitions", torch.bool] | None = None,
is_done: torch.Tensor | None = None,
next_states: States | None = None,
is_backward: bool = False,
log_rewards: TT["n_transitions", torch.float] | None = None,
log_probs: TT["n_transitions", torch.float] | None = None,
log_rewards: torch.Tensor | None = None,
log_probs: torch.Tensor | None = None,
):
"""Instantiates a container for transitions.
@@ -52,14 +51,14 @@ def __init__(
states: States object with uni-dimensional `batch_shape`, representing the
parents of the transitions.
actions: Actions chosen at the parents of each transition.
is_done: Whether the action is the exit action.
is_done: Tensor of shape (n_transitions,) indicating whether the action is the exit action.
next_states: States object with uni-dimensional `batch_shape`, representing
the children of the transitions.
is_backward: Whether the transitions are backward transitions (i.e.
`next_states` is the parent of states).
log_rewards: The log-rewards of the transitions (using a default value like
log_rewards: Tensor of shape (n_transitions,) containing the log-rewards of the transitions (using a default value like
`-float('inf')` for non-terminating transitions).
log_probs: The log-probabilities of the actions.
log_probs: Tensor of shape (n_transitions,) containing the log-probabilities of the actions.
Raises:
AssertionError: If states and next_states do not have matching
@@ -78,11 +77,15 @@ def __init__(
self.actions = (
actions if actions is not None else env.actions_from_batch_shape((0,))
)
assert self.actions.batch_shape == self.states.batch_shape

self.is_done = (
is_done
if is_done is not None
else torch.full(size=(0,), fill_value=False, dtype=torch.bool)
)
assert self.is_done.shape == (self.n_transitions,) and self.is_done.dtype == torch.bool

self.next_states = (
next_states
if next_states is not None
@@ -93,7 +96,9 @@ def __init__(
and self.states.batch_shape == self.next_states.batch_shape
)
self._log_rewards = log_rewards if log_rewards is not None else torch.zeros(0)
assert self._log_rewards.shape == (self.n_transitions,) and self._log_rewards.dtype == torch.float
self.log_probs = log_probs if log_probs is not None else torch.zeros(0)
assert self.log_probs.shape == (self.n_transitions,) and self.log_probs.dtype == torch.float

@property
def n_transitions(self) -> int:
@@ -124,7 +129,8 @@ def last_states(self) -> States:
return self.states[self.is_done]

@property
def log_rewards(self) -> TT["n_transitions", torch.float] | None:
def log_rewards(self) -> torch.Tensor | None:
"""Compute the tensor of shape (n_transitions,) containing the log rewards for the transitions."""
if self._log_rewards is not None:
return self._log_rewards
if self.is_backward:
@@ -143,13 +149,17 @@ def log_rewards(self) -> TT["n_transitions", torch.float] | None:
return log_rewards

@property
def all_log_rewards(self) -> TT["n_transitions", 2, torch.float]:
def all_log_rewards(self) -> torch.Tensor:
"""Calculate all log rewards for the transitions.
This is applicable to environments where all states are terminating. This
function evaluates the rewards for all transitions that do not end in the sink
state. This is useful for the Modified Detailed Balance loss.
Returns:
log_rewards: Tensor of shape (n_transitions, 2) containing the log rewards
for the transitions.
Raises:
NotImplementedError: when used for backward transitions.
"""
@@ -176,6 +186,8 @@ def all_log_rewards(self) -> TT["n_transitions", 2, torch.float]:
log_rewards[~is_sink_state, 1] = torch.log(
self.env.reward(self.next_states[~is_sink_state])
)

assert log_rewards.shape == (self.n_transitions, 2) and log_rewards.dtype == torch.float
return log_rewards

def __getitem__(self, index: int | Sequence[int]) -> Transitions:
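The indexing used by all_log_rewards is easier to see on hand-built tensors. A small illustration of the (n_transitions, 2) layout (the numbers are made up and no environment is involved; column 0 holding the parent-state log reward is inferred from the docstring, only column 1 is visible in the diff above):

import torch

n_transitions = 3
is_sink_state = torch.tensor([False, False, True])

log_rewards = torch.full((n_transitions, 2), -float("inf"))
log_rewards[~is_sink_state, 0] = torch.tensor([-0.5, -1.2])   # log R(s) of the parent states
log_rewards[~is_sink_state, 1] = torch.tensor([-0.9, -0.1])   # log R(s') of the non-sink children

assert log_rewards.shape == (n_transitions, 2)
assert log_rewards.dtype == torch.float  # torch.full with a float fill defaults to float32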
(Diffs for the remaining 21 changed files are not shown.)
