Fix formatting (#207)
* fix formatting

* remove extra quote
younik authored Oct 31, 2024
1 parent 57cc269 commit 0d593d8
Showing 18 changed files with 160 additions and 137 deletions.
18 changes: 9 additions & 9 deletions src/gfn/actions.py
@@ -32,12 +32,12 @@ def __init__(self, tensor: torch.Tensor):
Args:
tensor: tensors representing a batch of actions with shape (*batch_shape, *action_shape).
"""
assert tensor.shape[-len(self.action_shape):] == self.action_shape, (
f"Batched actions tensor has shape {tensor.shape}, but the expected action shape is {self.action_shape}."
)
assert (
tensor.shape[-len(self.action_shape) :] == self.action_shape
), f"Batched actions tensor has shape {tensor.shape}, but the expected action shape is {self.action_shape}."

self.tensor = tensor
self.batch_shape = tuple(self.tensor.shape)[:-len(self.action_shape)]
self.batch_shape = tuple(self.tensor.shape)[: -len(self.action_shape)]

@classmethod
def make_dummy_actions(cls, batch_shape: tuple[int]) -> Actions:
@@ -137,13 +137,13 @@ def compare(self, other: torch.Tensor) -> torch.Tensor:
Args:
other: tensor of actions to compare, with shape (*batch_shape, *action_shape).
Returns: boolean tensor of shape batch_shape indicating whether the actions are
equal.
"""
assert other.shape == self.batch_shape + self.action_shape, (
f"Expected shape {self.batch_shape + self.action_shape}, got {other.shape}."
)
assert (
other.shape == self.batch_shape + self.action_shape
), f"Expected shape {self.batch_shape + self.action_shape}, got {other.shape}."
out = self.tensor == other
n_batch_dims = len(self.batch_shape)

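The reformatted asserts above encode a shape contract: the trailing dimensions of the batched tensor must equal action_shape, and whatever precedes them is the batch_shape. A minimal standalone sketch of that contract, using a made-up action_shape of (2,) and batch shape of (4, 3):

import torch

action_shape = (2,)  # hypothetical; in the library this is set by each Actions subclass
tensor = torch.zeros(4, 3, 2)

# Trailing dims must match the action shape...
assert tensor.shape[-len(action_shape):] == action_shape
# ...and the leading dims form the batch shape, as in __init__ above.
batch_shape = tuple(tensor.shape)[: -len(action_shape)]
assert batch_shape == (4, 3)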
28 changes: 17 additions & 11 deletions src/gfn/containers/trajectories.py
@@ -1,7 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence, Union, Tuple

from typing import TYPE_CHECKING, Sequence, Tuple, Union

if TYPE_CHECKING:
from gfn.actions import Actions
@@ -92,20 +91,29 @@ def __init__(
if when_is_done is not None
else torch.full(size=(0,), fill_value=-1, dtype=torch.long)
)
assert self.when_is_done.shape == (self.n_trajectories,) and self.when_is_done.dtype == torch.long
assert (
self.when_is_done.shape == (self.n_trajectories,)
and self.when_is_done.dtype == torch.long
)

self._log_rewards = (
log_rewards
if log_rewards is not None
else torch.full(size=(0,), fill_value=0, dtype=torch.float)
)
assert self._log_rewards.shape == (self.n_trajectories,) and self._log_rewards.dtype == torch.float
assert (
self._log_rewards.shape == (self.n_trajectories,)
and self._log_rewards.dtype == torch.float
)

if log_probs is not None:
assert log_probs.shape == (self.max_length, self.n_trajectories) and log_probs.dtype == torch.float
assert (
log_probs.shape == (self.max_length, self.n_trajectories)
and log_probs.dtype == torch.float
)
else:
log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
self.log_probs = log_probs
self.log_probs = log_probs

self.estimator_outputs = estimator_outputs
if self.estimator_outputs is not None:
@@ -207,15 +215,13 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories:
)

@staticmethod
def extend_log_probs(
log_probs: torch.Tensor, new_max_length: int
) -> torch.Tensor:
def extend_log_probs(log_probs: torch.Tensor, new_max_length: int) -> torch.Tensor:
"""Extend the log_probs matrix by adding 0 until the required length is reached.
Args:
log_probs: The log_probs tensor of shape (max_length, n_trajectories) to extend.
new_max_length: The new length of the log_probs tensor.
Returns: The extended log_probs tensor of shape (new_max_length, n_trajectories).
"""
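As its docstring above describes, extend_log_probs pads the (max_length, n_trajectories) matrix with zeros up to new_max_length. A hedged sketch of that behavior; the pad call is an assumption for illustration, not necessarily the library's implementation:

import torch

def extend_log_probs_sketch(log_probs: torch.Tensor, new_max_length: int) -> torch.Tensor:
    max_length = log_probs.shape[0]
    if new_max_length <= max_length:
        return log_probs
    # pad=(0, 0, 0, k) leaves the columns untouched and appends k zero rows.
    return torch.nn.functional.pad(log_probs, (0, 0, 0, new_max_length - max_length))

assert extend_log_probs_sketch(torch.randn(5, 3), 8).shape == (8, 3)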
22 changes: 17 additions & 5 deletions src/gfn/containers/transitions.py
@@ -84,7 +84,10 @@ def __init__(
if is_done is not None
else torch.full(size=(0,), fill_value=False, dtype=torch.bool)
)
assert self.is_done.shape == (self.n_transitions,) and self.is_done.dtype == torch.bool
assert (
self.is_done.shape == (self.n_transitions,)
and self.is_done.dtype == torch.bool
)

self.next_states = (
next_states
@@ -96,9 +99,15 @@ def __init__(
and self.states.batch_shape == self.next_states.batch_shape
)
self._log_rewards = log_rewards if log_rewards is not None else torch.zeros(0)
assert self._log_rewards.shape == (self.n_transitions,) and self._log_rewards.dtype == torch.float
assert (
self._log_rewards.shape == (self.n_transitions,)
and self._log_rewards.dtype == torch.float
)
self.log_probs = log_probs if log_probs is not None else torch.zeros(0)
assert self.log_probs.shape == (self.n_transitions,) and self.log_probs.dtype == torch.float
assert (
self.log_probs.shape == (self.n_transitions,)
and self.log_probs.dtype == torch.float
)

@property
def n_transitions(self) -> int:
@@ -186,8 +195,11 @@ def all_log_rewards(self) -> torch.Tensor:
log_rewards[~is_sink_state, 1] = torch.log(
self.env.reward(self.next_states[~is_sink_state])
)

assert log_rewards.shape == (self.n_transitions, 2) and log_rewards.dtype == torch.float

assert (
log_rewards.shape == (self.n_transitions, 2)
and log_rewards.dtype == torch.float
)
return log_rewards

def __getitem__(self, index: int | Sequence[int]) -> Transitions:
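The all_log_rewards hunk above fills a (n_transitions, 2) tensor by masked assignment, writing rewards only where the next state is not a sink. A self-contained sketch of that pattern, with made-up mask and reward values:

import torch

n_transitions = 4
is_sink_state = torch.tensor([False, True, False, True])
valid_next_state_rewards = torch.tensor([0.5, 2.0])  # hypothetical env rewards

log_rewards = torch.full((n_transitions, 2), -float("inf"))
log_rewards[~is_sink_state, 1] = torch.log(valid_next_state_rewards)
assert log_rewards.shape == (n_transitions, 2) and log_rewards.dtype == torch.float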
21 changes: 13 additions & 8 deletions src/gfn/gflownet/base.py
@@ -1,6 +1,6 @@
import math
from abc import ABC, abstractmethod
from typing import Generic, Tuple, TypeVar, Union, Any
from typing import Any, Generic, Tuple, TypeVar, Union

import torch
import torch.nn as nn
@@ -211,7 +211,6 @@ def get_pfs_and_pbs(
# Using all non-initial states, calculate the backward policy, and the logprobs
# of those actions.
if trajectories.conditioning is not None:

# We need to index the conditioning vector to broadcast over the states.
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
@@ -242,8 +241,14 @@ def get_pfs_and_pbs(
log_pb_trajectories_slice[~valid_actions.is_exit] = valid_log_pb_actions
log_pb_trajectories[~trajectories.actions.is_dummy] = log_pb_trajectories_slice

assert log_pf_trajectories.shape == (trajectories.max_length, trajectories.n_trajectories)
assert log_pb_trajectories.shape == (trajectories.max_length, trajectories.n_trajectories)
assert log_pf_trajectories.shape == (
trajectories.max_length,
trajectories.n_trajectories,
)
assert log_pb_trajectories.shape == (
trajectories.max_length,
trajectories.n_trajectories,
)
return log_pf_trajectories, log_pb_trajectories

def get_trajectories_scores(
@@ -252,15 +257,15 @@ def get_trajectories_scores(
recalculate_all_logprobs: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Given a batch of trajectories, calculate forward & backward policy scores.
Args:
trajectories: Trajectories to evaluate.
recalculate_all_logprobs: Whether to re-evaluate all logprobs.
Returns: A tuple of float tensors of shape (n_trajectories,)
containing the total log_pf, total log_pb, and the total
log-likelihood of the trajectories.
"""
log_pf_trajectories, log_pb_trajectories = self.get_pfs_and_pbs(
trajectories, recalculate_all_logprobs=recalculate_all_logprobs
@@ -279,7 +284,7 @@ def get_trajectories_scores(
torch.isinf(total_log_pb_trajectories)
):
raise ValueError("Infinite logprobs found")

assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,)
assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,)
return (
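The (max_length, n_trajectories) matrices asserted above reduce to the (n_trajectories,) totals that get_trajectories_scores returns. A sketch of that reduction, assuming dummy steps carry log-prob 0 so a plain sum over the step dimension suffices:

import torch

max_length, n_trajectories = 6, 4
log_pf_trajectories = torch.randn(max_length, n_trajectories)
log_pb_trajectories = torch.randn(max_length, n_trajectories)

total_log_pf = log_pf_trajectories.sum(dim=0)
total_log_pb = log_pb_trajectories.sum(dim=0)
assert total_log_pf.shape == (n_trajectories,)
assert total_log_pb.shape == (n_trajectories,)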
6 changes: 3 additions & 3 deletions src/gfn/gflownet/detailed_balance.py
@@ -6,7 +6,7 @@
from gfn.containers import Trajectories, Transitions
from gfn.env import Env
from gfn.gflownet.base import PFBasedGFlowNet
from gfn.modules import GFNModule, ScalarEstimator, ConditionalScalarEstimator
from gfn.modules import ConditionalScalarEstimator, GFNModule, ScalarEstimator
from gfn.utils.common import has_log_probs
from gfn.utils.handlers import (
has_conditioning_exception_handler,
@@ -91,7 +91,7 @@ def get_scores(
- If transitions have log_probs attribute, use them - this is usually for on-policy learning
- Else, re-evaluate the log_probs using the current self.pf - this is usually for
off-policy learning with replay buffer
Returns: A tuple of three tensors of shape (n_transitions,), representing the
log probabilities of the actions, the log probabilities of the backward actions, and the scores.
@@ -194,7 +194,7 @@ def get_scores(

assert valid_log_pf_actions.shape == (transitions.n_transitions,)
assert log_pb_actions.shape == (transitions.n_transitions,)
assert scores.shape == (transitions.n_transitions,)
assert scores.shape == (transitions.n_transitions,)
return valid_log_pf_actions, log_pb_actions, scores

def loss(self, env: Env, transitions: Transitions) -> torch.Tensor:
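For context on the (n_transitions,) shapes asserted in get_scores: a common formulation of the detailed-balance score is log F(s) + log P_F(s'|s) - log F(s') - log P_B(s|s'), squared and averaged for the loss. The sketch below uses random stand-ins for the estimator outputs and is an assumption about the formulation, not a copy of the module's code:

import torch

n_transitions = 5
log_F_s = torch.randn(n_transitions)      # stand-in: log state flow of s
log_F_next = torch.randn(n_transitions)   # stand-in: log state flow of s'
log_pf_actions = torch.randn(n_transitions)
log_pb_actions = torch.randn(n_transitions)

scores = log_F_s + log_pf_actions - log_F_next - log_pb_actions
loss = (scores**2).mean()
assert scores.shape == (n_transitions,)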
11 changes: 6 additions & 5 deletions src/gfn/gflownet/flow_matching.py
@@ -1,16 +1,16 @@
from typing import Tuple, Any, Union
from typing import Any, Tuple, Union

import torch

from gfn.containers import Trajectories
from gfn.env import Env
from gfn.gflownet.base import GFlowNet
from gfn.modules import DiscretePolicyEstimator, ConditionalDiscretePolicyEstimator
from gfn.modules import ConditionalDiscretePolicyEstimator, DiscretePolicyEstimator
from gfn.samplers import Sampler
from gfn.states import DiscreteStates, States
from gfn.utils.handlers import (
no_conditioning_exception_handler,
has_conditioning_exception_handler,
no_conditioning_exception_handler,
)


@@ -109,7 +109,6 @@ def flow_matching_loss(
)

if conditioning is not None:

# Mask out only valid conditioning elements.
valid_backward_conditioning = conditioning[valid_backward_mask]
valid_forward_conditioning = conditioning[valid_forward_mask]
@@ -204,7 +203,9 @@ def loss(
)
return fm_loss + self.alpha * rm_loss

def to_training_samples(self, trajectories: Trajectories) -> Union[
def to_training_samples(
self, trajectories: Trajectories
) -> Union[
Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor],
Tuple[DiscreteStates, DiscreteStates, None, None],
Tuple[States, States, torch.Tensor, torch.Tensor],
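For orientation, the flow-matching objective behind flow_matching_loss balances incoming and outgoing flows at each intermediate state. A hedged sketch of one way to write that residual, with random tensors standing in for the edge-flow estimator's outputs:

import torch

n_states, n_actions = 7, 3
log_incoming_flows = torch.randn(n_states, n_actions)
log_outgoing_flows = torch.randn(n_states, n_actions)

# Log-sum of inflows should match log-sum of outflows at each state.
residual = torch.logsumexp(log_incoming_flows, dim=-1) - torch.logsumexp(
    log_outgoing_flows, dim=-1
)
fm_loss = (residual**2).mean()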
27 changes: 14 additions & 13 deletions src/gfn/gflownet/sub_trajectory_balance.py
@@ -6,14 +6,15 @@
from gfn.containers import Trajectories
from gfn.env import Env
from gfn.gflownet.base import TrajectoryBasedGFlowNet
from gfn.modules import GFNModule, ScalarEstimator, ConditionalScalarEstimator
from gfn.modules import ConditionalScalarEstimator, GFNModule, ScalarEstimator
from gfn.utils.handlers import (
has_conditioning_exception_handler,
no_conditioning_exception_handler,
)


ContributionsTensor = torch.Tensor # shape: [max_len * (1 + max_len) / 2, n_trajectories]
ContributionsTensor = (
torch.Tensor
) # shape: [max_len * (1 + max_len) / 2, n_trajectories]
CumulativeLogProbsTensor = torch.Tensor # shape: [max_length + 1, n_trajectories]
LogStateFlowsTensor = torch.Tensor # shape: [max_length, n_trajectories]
LogTrajectoriesTensor = torch.Tensor # shape: [max_length, n_trajectories]
@@ -115,7 +116,7 @@ def cumulative_logprobs(
trajectories: a batch of trajectories.
log_p_trajectories: log probabilities of each transition in each trajectory.
Returns: Tensor of shape (max_length + 1, n_trajectories), containing the
Returns: Tensor of shape (max_length + 1, n_trajectories), containing the
cumulative sum of log probabilities of each trajectory.
"""
return torch.cat(
@@ -136,12 +137,12 @@ def calculate_preds(
) -> PredictionsTensor:
"""
Calculate the predictions tensor for the current sub-trajectory length.
Args:
log_pf_trajectories_cum: Tensor of shape (max_length + 1, n_trajectories) containing the cumulative log probabilities of the forward actions.
log_state_flows: Tensor of shape (max_length, n_trajectories) containing the log state flows.
i: The sub-trajectory length.
Returns: The predictions tensor of shape (max_length + 1 - i, n_trajectories).
"""
current_log_state_flows = (
@@ -179,7 +180,7 @@ def calculate_targets(
sink_states_mask: A mask tensor of shape (max_length, n_trajectories) representing sink states.
full_mask: A mask tensor of shape (max_length, n_trajectories) representing full states.
i: The sub-trajectory length.
Returns: The targets tensor of shape (max_length + 1 - i, n_trajectories).
"""
targets = torch.full_like(preds, fill_value=-float("inf"))
@@ -262,7 +263,7 @@ def calculate_masks(
Args:
log_state_flows: Tensor of shape (max_length, n_trajectories) containing the log state flows.
trajectories: The trajectories data.
Returns: a tuple of three mask tensors of shape (max_length, n_trajectories).
"""
sink_states_mask = log_state_flows == -float("inf")
@@ -353,7 +354,7 @@ def get_equal_within_contributions(
Args:
trajectories: The trajectories data.
all_scores: The scores tensor.
Returns: The contributions tensor of shape (max_len * (1 + max_len) / 2, n_trajectories).
"""
del all_scores
@@ -383,7 +384,7 @@ def get_equal_contributions(
Args:
trajectories: The trajectories data.
all_scores: The scores tensor.
Returns: The contributions tensor of shape (max_len * (1 + max_len) / 2, n_trajectories).
"""
is_done = trajectories.when_is_done
@@ -402,7 +403,7 @@ def get_tb_contributions(
Args:
trajectories: The trajectories data.
all_scores: The scores tensor.
Returns: The contributions tensor of shape (max_len * (1 + max_len) / 2, n_trajectories).
"""
max_len = trajectories.max_length
@@ -427,7 +428,7 @@ def get_modified_db_contributions(
Args:
trajectories: The trajectories data.
all_scores: The scores tensor.
Returns: The contributions tensor of shape (max_len * (1 + max_len) / 2, n_trajectories).
"""
del all_scores
@@ -461,7 +462,7 @@ def get_geometric_within_contributions(
Args:
trajectories: The trajectories data.
all_scores: The scores tensor.
Returns: The contributions tensor of shape (max_len * (1 + max_len) / 2, n_trajectories).
"""
del all_scores
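The ContributionsTensor comment reformatted at the top of this file gives the row count max_len * (1 + max_len) / 2, which is the number of contiguous sub-trajectories, one per start/end pair. A quick check:

max_len = 4
sub_trajectories = [
    (start, end) for start in range(max_len) for end in range(start + 1, max_len + 1)
]
assert len(sub_trajectories) == max_len * (1 + max_len) // 2  # 10 for max_len == 4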