
Commit

Merge branch 'master' into hyeok9855/local-search
hyeok9855 committed Nov 2, 2024
2 parents 1ccf16c + 9c9e1af commit a3af467
Showing 21 changed files with 187 additions and 157 deletions.
20 changes: 20 additions & 0 deletions .github/workflows/pre-commit.yml
@@ -0,0 +1,20 @@
# https://pre-commit.com
# This GitHub Action assumes that the repo contains a valid .pre-commit-config.yaml file.
---
name: pre-commit
on: [push]

permissions:
contents: read

jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- run: pip install .[all]
- run: pre-commit --version
- run: pre-commit run --all-files
18 changes: 9 additions & 9 deletions src/gfn/actions.py
@@ -32,12 +32,12 @@ def __init__(self, tensor: torch.Tensor):
Args:
tensor: tensors representing a batch of actions with shape (*batch_shape, *action_shape).
"""
assert tensor.shape[-len(self.action_shape):] == self.action_shape, (
f"Batched actions tensor has shape {tensor.shape}, but the expected action shape is {self.action_shape}."
)
assert (
tensor.shape[-len(self.action_shape) :] == self.action_shape
), f"Batched actions tensor has shape {tensor.shape}, but the expected action shape is {self.action_shape}."

self.tensor = tensor
self.batch_shape = tuple(self.tensor.shape)[:-len(self.action_shape)]
self.batch_shape = tuple(self.tensor.shape)[: -len(self.action_shape)]

@classmethod
def make_dummy_actions(cls, batch_shape: tuple[int]) -> Actions:
@@ -137,13 +137,13 @@ def compare(self, other: torch.Tensor) -> torch.Tensor:
Args:
other: tensor of actions to compare, with shape (*batch_shape, *action_shape).
Returns: boolean tensor of shape batch_shape indicating whether the actions are
equal.
"""
assert other.shape == self.batch_shape + self.action_shape, (
f"Expected shape {self.batch_shape + self.action_shape}, got {other.shape}."
)
assert (
other.shape == self.batch_shape + self.action_shape
), f"Expected shape {self.batch_shape + self.action_shape}, got {other.shape}."
out = self.tensor == other
n_batch_dims = len(self.batch_shape)

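For reference, a minimal standalone sketch of the shape bookkeeping that the reformatted asserts above enforce (plain tensors, not the `Actions` class itself): the trailing `action_shape` dimensions are stripped to recover `batch_shape`, and `compare` reduces the equality check back down to that shape.

```python
import torch

# Standalone sketch of the shape logic above, assuming a 1-D action_shape.
action_shape = (1,)                         # one integer index per action
tensor = torch.zeros(4, 3, *action_shape)   # batched actions, batch_shape = (4, 3)

batch_shape = tuple(tensor.shape)[: -len(action_shape)]
assert batch_shape == (4, 3)

# compare(): elementwise equality over the action dims, reduced to batch_shape.
other = torch.zeros(*batch_shape, *action_shape)
out = (tensor == other).all(dim=-1)
assert out.shape == batch_shape
```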
6 changes: 6 additions & 0 deletions src/gfn/containers/replay_buffer.py
@@ -180,6 +180,12 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):

# Our buffer is full and we will prioritize diverse, high reward additions.
else:
if (
self.training_objects.log_rewards is None
or training_objects.log_rewards is None
):
raise ValueError("log_rewards must be defined for prioritized replay.")

# Sort the incoming elements by their logrewards.
ix = torch.argsort(training_objects.log_rewards, descending=True)
training_objects = training_objects[ix]
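A toy sketch of the prioritization idea in the hunk above, for reference: sort incoming items by log reward and keep the highest-reward entries once the buffer is full. It uses made-up tensors rather than the buffer's actual containers and leaves out the diversity filtering the comment mentions.

```python
import torch

# Toy sketch only: sort incoming items by log reward (as in the hunk above),
# then keep the highest-reward entries across buffer + incoming.
capacity = 4
buffer_log_rewards = torch.tensor([0.1, 0.3, 0.2, 0.5])   # buffer already full
incoming_log_rewards = torch.tensor([0.7, 0.05, 0.4])

ix = torch.argsort(incoming_log_rewards, descending=True)
incoming_sorted = incoming_log_rewards[ix]                 # tensor([0.7000, 0.4000, 0.0500])

merged = torch.cat([buffer_log_rewards, incoming_sorted])
keep = torch.argsort(merged, descending=True)[:capacity]
buffer_log_rewards = merged[keep]                          # tensor([0.7000, 0.5000, 0.4000, 0.3000])
```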
20 changes: 5 additions & 15 deletions src/gfn/containers/trajectories.py
@@ -1,7 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence, Union, Tuple

from typing import TYPE_CHECKING, Sequence, Tuple, Union

if TYPE_CHECKING:
from gfn.actions import Actions
@@ -16,11 +15,6 @@
from gfn.utils.common import has_log_probs


def is_tensor(t) -> bool:
"""Checks whether t is a torch.Tensor instance."""
return isinstance(t, torch.Tensor)


# TODO: remove env from this class?
class Trajectories(Container):
"""Container for complete trajectories (starting in $s_0$ and ending in $s_f$).
@@ -114,7 +108,7 @@ def __init__(
)
else:
log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
self.log_probs = log_probs
self.log_probs: torch.Tensor = log_probs

self.estimator_outputs = estimator_outputs
if self.estimator_outputs is not None:
@@ -188,7 +182,7 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories:
log_rewards = (
self._log_rewards[index] if self._log_rewards is not None else None
)
if is_tensor(self.estimator_outputs):
if self.estimator_outputs is not None:
# TODO: Is there a safer way to index self.estimator_outputs for
# for n-dimensional estimator outputs?
#
@@ -293,13 +287,9 @@ def extend(self, other: Trajectories) -> None:

# Either set, or append, estimator outputs if they exist in the submitted
# trajectory.
if self.estimator_outputs is None and isinstance(
other.estimator_outputs, torch.Tensor
):
if self.estimator_outputs is None and other.estimator_outputs is not None:
self.estimator_outputs = other.estimator_outputs
elif isinstance(self.estimator_outputs, torch.Tensor) and isinstance(
other.estimator_outputs, torch.Tensor
):
elif self.estimator_outputs is not None and other.estimator_outputs is not None:
batch_shape = self.actions.batch_shape
n_bs = len(batch_shape)

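The hunks above swap `is_tensor`/`isinstance` checks for plain `is not None` tests. Assuming `estimator_outputs` is only ever `None` or a tensor, the two are equivalent, and the simpler check reads better alongside the new type annotations. A tiny sketch of that assumption:

```python
import torch
from typing import Optional

# Under the assumption that estimator_outputs is either None or a torch.Tensor,
# `x is not None` and `isinstance(x, torch.Tensor)` answer the same question.
estimator_outputs: Optional[torch.Tensor] = torch.randn(3, 8)
assert (estimator_outputs is not None) == isinstance(estimator_outputs, torch.Tensor)
```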
22 changes: 17 additions & 5 deletions src/gfn/containers/transitions.py
@@ -84,7 +84,10 @@ def __init__(
if is_done is not None
else torch.full(size=(0,), fill_value=False, dtype=torch.bool)
)
assert self.is_done.shape == (self.n_transitions,) and self.is_done.dtype == torch.bool
assert (
self.is_done.shape == (self.n_transitions,)
and self.is_done.dtype == torch.bool
)

self.next_states = (
next_states
@@ -96,9 +99,15 @@
and self.states.batch_shape == self.next_states.batch_shape
)
self._log_rewards = log_rewards if log_rewards is not None else torch.zeros(0)
assert self._log_rewards.shape == (self.n_transitions,) and self._log_rewards.dtype == torch.float
assert (
self._log_rewards.shape == (self.n_transitions,)
and self._log_rewards.dtype == torch.float
)
self.log_probs = log_probs if log_probs is not None else torch.zeros(0)
assert self.log_probs.shape == (self.n_transitions,) and self.log_probs.dtype == torch.float
assert (
self.log_probs.shape == (self.n_transitions,)
and self.log_probs.dtype == torch.float
)

@property
def n_transitions(self) -> int:
@@ -186,8 +195,11 @@ def all_log_rewards(self) -> torch.Tensor:
log_rewards[~is_sink_state, 1] = torch.log(
self.env.reward(self.next_states[~is_sink_state])
)

assert log_rewards.shape == (self.n_transitions, 2) and log_rewards.dtype == torch.float

assert (
log_rewards.shape == (self.n_transitions, 2)
and log_rewards.dtype == torch.float
)
return log_rewards

def __getitem__(self, index: int | Sequence[int]) -> Transitions:
8 changes: 4 additions & 4 deletions src/gfn/env.py
@@ -54,8 +54,8 @@ def __init__(
assert s0.shape == state_shape
if sf is None:
sf = torch.full(s0.shape, -float("inf")).to(self.device)
assert sf.shape == state_shape
self.sf = sf
self.sf: torch.Tensor = sf
assert self.sf.shape == state_shape
self.state_shape = state_shape
self.action_shape = action_shape
self.dummy_action = dummy_action
@@ -381,11 +381,11 @@ def __init__(

# The default dummy action is -1.
if dummy_action is None:
dummy_action = torch.tensor([-1], device=device)
dummy_action: torch.Tensor = torch.tensor([-1], device=device)

# The default exit action index is the final element of the action space.
if exit_action is None:
exit_action = torch.tensor([n_actions - 1], device=device)
exit_action: torch.Tensor = torch.tensor([n_actions - 1], device=device)

assert s0.shape == state_shape
assert dummy_action.shape == action_shape
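A small sketch of the default actions described in the comments above, assuming a discrete environment with `n_actions = 4`:

```python
import torch

# Defaults per the comments above: the dummy action is -1, and the exit action
# is the last index of the action space.
n_actions = 4
device = torch.device("cpu")

dummy_action = torch.tensor([-1], device=device)            # tensor([-1])
exit_action = torch.tensor([n_actions - 1], device=device)  # tensor([3])

assert dummy_action.shape == (1,) and exit_action.item() == n_actions - 1
```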
21 changes: 13 additions & 8 deletions src/gfn/gflownet/base.py
@@ -1,6 +1,6 @@
import math
from abc import ABC, abstractmethod
from typing import Generic, Tuple, TypeVar, Union, Any
from typing import Any, Generic, Tuple, TypeVar, Union

import torch
import torch.nn as nn
@@ -211,7 +211,6 @@ def get_pfs_and_pbs(
# Using all non-initial states, calculate the backward policy, and the logprobs
# of those actions.
if trajectories.conditioning is not None:

# We need to index the conditioning vector to broadcast over the states.
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
@@ -242,8 +241,14 @@ def get_pfs_and_pbs(
log_pb_trajectories_slice[~valid_actions.is_exit] = valid_log_pb_actions
log_pb_trajectories[~trajectories.actions.is_dummy] = log_pb_trajectories_slice

assert log_pf_trajectories.shape == (trajectories.max_length, trajectories.n_trajectories)
assert log_pb_trajectories.shape == (trajectories.max_length, trajectories.n_trajectories)
assert log_pf_trajectories.shape == (
trajectories.max_length,
trajectories.n_trajectories,
)
assert log_pb_trajectories.shape == (
trajectories.max_length,
trajectories.n_trajectories,
)
return log_pf_trajectories, log_pb_trajectories

def get_trajectories_scores(
Expand All @@ -252,15 +257,15 @@ def get_trajectories_scores(
recalculate_all_logprobs: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Given a batch of trajectories, calculate forward & backward policy scores.
Args:
trajectories: Trajectories to evaluate.
recalculate_all_logprobs: Whether to re-evaluate all logprobs.
Returns: A tuple of float tensors of shape (n_trajectories,)
containing the total log_pf, total log_pb, and the total
log-likelihood of the trajectories.
"""
log_pf_trajectories, log_pb_trajectories = self.get_pfs_and_pbs(
trajectories, recalculate_all_logprobs=recalculate_all_logprobs
@@ -279,7 +284,7 @@ def get_trajectories_scores(
torch.isinf(total_log_pb_trajectories)
):
raise ValueError("Infinite logprobs found")

assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,)
assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,)
return (
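For reference, a self-contained sketch of the shape contract the new asserts above check: per-step log-probs have shape `(max_length, n_trajectories)`, and the trajectory totals described in the `get_trajectories_scores` docstring reduce over the step axis to `(n_trajectories,)`.

```python
import torch

# Sketch of the shapes asserted above; padded (dummy) steps are assumed to hold 0,
# so summing over the step axis gives one total per trajectory.
max_length, n_trajectories = 5, 3
log_pf_trajectories = torch.randn(max_length, n_trajectories)
log_pb_trajectories = torch.randn(max_length, n_trajectories)

total_log_pf = log_pf_trajectories.sum(dim=0)
total_log_pb = log_pb_trajectories.sum(dim=0)
assert total_log_pf.shape == (n_trajectories,)
assert total_log_pb.shape == (n_trajectories,)
```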
6 changes: 3 additions & 3 deletions src/gfn/gflownet/detailed_balance.py
@@ -6,7 +6,7 @@
from gfn.containers import Trajectories, Transitions
from gfn.env import Env
from gfn.gflownet.base import PFBasedGFlowNet
from gfn.modules import GFNModule, ScalarEstimator, ConditionalScalarEstimator
from gfn.modules import ConditionalScalarEstimator, GFNModule, ScalarEstimator
from gfn.utils.common import has_log_probs
from gfn.utils.handlers import (
has_conditioning_exception_handler,
@@ -91,7 +91,7 @@ def get_scores(
- If transitions have log_probs attribute, use them - this is usually for on-policy learning
- Else, re-evaluate the log_probs using the current self.pf - this is usually for
off-policy learning with replay buffer
Returns: A tuple of three tensors of shapes (n_transitions,), representing the
log probabilities of the actions, the log probabilities of the backward actions, and the scores.
@@ -194,7 +194,7 @@ def get_scores(

assert valid_log_pf_actions.shape == (transitions.n_transitions,)
assert log_pb_actions.shape == (transitions.n_transitions,)
assert scores.shape == (transitions.n_transitions,)
assert scores.shape == (transitions.n_transitions,)
return valid_log_pf_actions, log_pb_actions, scores

def loss(self, env: Env, transitions: Transitions) -> torch.Tensor:
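A hedged, self-contained sketch of the on-/off-policy branching described in the `get_scores` docstring above; the function name and the `recompute_fn` callback are illustrative only, and the real method also handles conditioning and exit actions.

```python
import torch

# Illustrative only: reuse stored log-probs when present (on-policy), otherwise
# re-evaluate them under the current forward policy (off-policy / replay buffer).
def pick_log_pf_actions(stored_log_probs, recompute_fn, recalculate_all_logprobs=False):
    if stored_log_probs is not None and not recalculate_all_logprobs:
        return stored_log_probs      # on-policy: trust the sampler's stored values
    return recompute_fn()            # off-policy: recompute with the current pf

n_transitions = 6
stored = torch.randn(n_transitions)
reused = pick_log_pf_actions(stored, lambda: torch.zeros(n_transitions))
recomputed = pick_log_pf_actions(None, lambda: torch.zeros(n_transitions))
assert torch.equal(reused, stored) and recomputed.shape == (n_transitions,)
```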
11 changes: 6 additions & 5 deletions src/gfn/gflownet/flow_matching.py
@@ -1,16 +1,16 @@
from typing import Tuple, Any, Union
from typing import Any, Tuple, Union

import torch

from gfn.containers import Trajectories
from gfn.env import Env
from gfn.gflownet.base import GFlowNet
from gfn.modules import DiscretePolicyEstimator, ConditionalDiscretePolicyEstimator
from gfn.modules import ConditionalDiscretePolicyEstimator, DiscretePolicyEstimator
from gfn.samplers import Sampler
from gfn.states import DiscreteStates, States
from gfn.utils.handlers import (
no_conditioning_exception_handler,
has_conditioning_exception_handler,
no_conditioning_exception_handler,
)


@@ -109,7 +109,6 @@ def flow_matching_loss(
)

if conditioning is not None:

# Mask out only valid conditioning elements.
valid_backward_conditioning = conditioning[valid_backward_mask]
valid_forward_conditioning = conditioning[valid_forward_mask]
@@ -204,7 +203,9 @@ def loss(
)
return fm_loss + self.alpha * rm_loss

def to_training_samples(self, trajectories: Trajectories) -> Union[
def to_training_samples(
self, trajectories: Trajectories
) -> Union[
Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor],
Tuple[DiscreteStates, DiscreteStates, None, None],
Tuple[States, States, torch.Tensor, torch.Tensor],