Skip to content

Commit

Permalink
loss reduction options
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Aug 30, 2024
1 parent cedc9b0 commit ca2c698
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 20 deletions.
18 changes: 17 additions & 1 deletion src/gfn/gflownet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@
)


def loss_reduce(loss, method):
"""Utility function to handle loss aggregation strategies."""
if method == "mean":
return torch.mean(loss)
elif method == "sum":
return torch.sum(loss)
elif method == "none":
return loss


class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]):
"""Abstract Base Class for GFlowNets.
Expand Down Expand Up @@ -73,7 +83,7 @@ def to_training_samples(self, trajectories: Trajectories) -> TrainingSampleType:
"""Converts trajectories to training samples. The type depends on the GFlowNet."""

@abstractmethod
def loss(self, env: Env, training_objects):
def loss(self, env: Env, training_objects, reduction: str):
"""Computes the loss given the training objects."""


Expand Down Expand Up @@ -116,6 +126,12 @@ def pf_pb_named_parameters(self):
def pf_pb_parameters(self):
return [v for k, v in self.named_parameters() if "pb" in k or "pf" in k]

def logF_named_parameters(self):
return {k: v for k, v in self.named_parameters() if "logF" in k}

def logF_parameters(self):
return [v for k, v in self.named_parameters() if "logF" in k]


class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]):
def get_pfs_and_pbs(
Expand Down
10 changes: 6 additions & 4 deletions src/gfn/gflownet/detailed_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from gfn.containers import Trajectories, Transitions
from gfn.env import Env
from gfn.gflownet.base import PFBasedGFlowNet
from gfn.gflownet.base import PFBasedGFlowNet, loss_reduce
from gfn.modules import GFNModule, ScalarEstimator
from gfn.utils.common import has_log_probs

Expand Down Expand Up @@ -130,13 +130,14 @@ def get_scores(

return (valid_log_pf_actions, log_pb_actions, scores)

def loss(self, env: Env, transitions: Transitions) -> TT[0, float]:
def loss(self, env: Env, transitions: Transitions, reduction: str = "mean") -> TT[0, float]:
"""Detailed balance loss.
The detailed balance loss is described in section
3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266)."""
_, _, scores = self.get_scores(env, transitions)
loss = torch.mean(scores**2)
scores = scores**2
loss = loss_reduce(scores, reduction)

if torch.isnan(loss):
raise ValueError("loss is nan")
Expand Down Expand Up @@ -215,7 +216,8 @@ def get_scores(
def loss(self, env: Env, transitions: Transitions) -> TT[0, float]:
"""Calculates the modified detailed balance loss."""
scores = self.get_scores(transitions)
return torch.mean(scores**2)
scores = scores**2
return loss_reduce(loss, reduction)

def to_training_samples(self, trajectories: Trajectories) -> Transitions:
return trajectories.to_transitions()
18 changes: 11 additions & 7 deletions src/gfn/gflownet/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from gfn.containers import Trajectories
from gfn.env import Env
from gfn.gflownet.base import GFlowNet
from gfn.gflownet.base import GFlowNet, loss_reduce
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import Sampler
from gfn.states import DiscreteStates
Expand Down Expand Up @@ -60,6 +60,7 @@ def flow_matching_loss(
self,
env: Env,
states: DiscreteStates,
reduction: str = "mean",
) -> TT["n_trajectories", torch.float]:
"""Computes the FM for the provided states.
Expand Down Expand Up @@ -113,11 +114,12 @@ def flow_matching_loss(

log_incoming_flows = torch.logsumexp(incoming_log_flows, dim=-1)
log_outgoing_flows = torch.logsumexp(outgoing_log_flows, dim=-1)
scores = (log_incoming_flows - log_outgoing_flows).pow(2)

return (log_incoming_flows - log_outgoing_flows).pow(2).mean()
return loss_reduce(scores, reduction)

def reward_matching_loss(
self, env: Env, terminating_states: DiscreteStates
self, env: Env, terminating_states: DiscreteStates,reduction: str = "mean"
) -> TT[0, float]:
"""Calculates the reward matching loss from the terminating states."""
del env # Unused
Expand All @@ -127,10 +129,12 @@ def reward_matching_loss(
# Handle the boundary condition (for all x, F(X->S_f) = R(x)).
terminating_log_edge_flows = log_edge_flows[:, -1]
log_rewards = terminating_states.log_rewards
return (terminating_log_edge_flows - log_rewards).pow(2).mean()
scores = (terminating_log_edge_flows - log_rewards).pow(2)

return loss_reduce(scores, reduction)

def loss(
self, env: Env, states_tuple: Tuple[DiscreteStates, DiscreteStates]
self, env: Env, states_tuple: Tuple[DiscreteStates, DiscreteStates], reduction: str = "mean"
) -> TT[0, float]:
"""Given a batch of non-terminal and terminal states, compute a loss.
Expand All @@ -139,8 +143,8 @@ def loss(
(i.e. non-terminal states), and the second one being the terminal states of the
trajectories."""
intermediary_states, terminating_states = states_tuple
fm_loss = self.flow_matching_loss(env, intermediary_states)
rm_loss = self.reward_matching_loss(env, terminating_states)
fm_loss = self.flow_matching_loss(env, intermediary_states, reduction=reduction)
rm_loss = self.reward_matching_loss(env, terminating_states, reduction=reduction)
return fm_loss + self.alpha * rm_loss

def to_training_samples(
Expand Down
14 changes: 9 additions & 5 deletions src/gfn/gflownet/sub_trajectory_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from gfn.containers import Trajectories
from gfn.env import Env
from gfn.gflownet.base import TrajectoryBasedGFlowNet
from gfn.gflownet.base import TrajectoryBasedGFlowNet, loss_reduce
from gfn.modules import GFNModule, ScalarEstimator

ContributionsTensor = TT["max_len * (1 + max_len) / 2", "n_trajectories"]
Expand Down Expand Up @@ -382,15 +382,17 @@ def get_geometric_within_contributions(

return contributions

def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]:
def loss(self, env: Env, trajectories: Trajectories, reduction: str = "mean") -> TT[0, float]:
# Get all scores and masks from the trajectories.
scores, flattening_masks = self.get_scores(env, trajectories)
flattening_mask = torch.cat(flattening_masks)
all_scores = torch.cat(scores, 0)

if self.weighting == "DB":
# Longer trajectories contribute more to the loss
return scores[0][~flattening_masks[0]].pow(2).mean()
# Longer trajectories contribute more to the loss.
# TODO: is this correct with `loss_reduce`?
final_scores = scores[0][~flattening_masks[0]].pow(2)
return loss_reduce(final_scores, reduction)

elif self.weighting == "geometric":
# The position i of the following 1D tensor represents the number of sub-
Expand Down Expand Up @@ -440,4 +442,6 @@ def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]:
flat_contributions.sum() - 1.0
).abs() < 1e-5, f"{flat_contributions.sum()}"
losses = flat_contributions * all_scores[~flattening_mask].pow(2)
return losses.sum()

# TODO: default was sum, does this even work with mean?
return loss_reduce(final_scores, reduction)
10 changes: 7 additions & 3 deletions src/gfn/gflownet/trajectory_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from gfn.containers import Trajectories
from gfn.env import Env
from gfn.gflownet.base import TrajectoryBasedGFlowNet
from gfn.gflownet.base import TrajectoryBasedGFlowNet, loss_reduce
from gfn.modules import GFNModule


Expand Down Expand Up @@ -46,6 +46,7 @@ def loss(
env: Env,
trajectories: Trajectories,
recalculate_all_logprobs: bool = False,
reduction: str = "mean",
) -> TT[0, float]:
"""Trajectory balance loss.
Expand All @@ -59,7 +60,8 @@ def loss(
_, _, scores = self.get_trajectories_scores(
trajectories, recalculate_all_logprobs=recalculate_all_logprobs
)
loss = (scores + self.logZ).pow(2).mean()
scores = (scores + self.logZ).pow(2)
loss = loss_reduce(scores, reduction)
if torch.isnan(loss):
raise ValueError("loss is nan")

Expand Down Expand Up @@ -90,6 +92,7 @@ def loss(
env: Env,
trajectories: Trajectories,
recalculate_all_logprobs: bool = False,
reduction: str = "mean",
) -> TT[0, float]:
"""Log Partition Variance loss.
Expand All @@ -100,7 +103,8 @@ def loss(
_, _, scores = self.get_trajectories_scores(
trajectories, recalculate_all_logprobs=recalculate_all_logprobs
)
loss = (scores - scores.mean()).pow(2).mean()
scores = (scores - scores.mean()).pow(2)
loss = loss_reduce(scores, reduction)
if torch.isnan(loss):
raise ValueError("loss is NaN.")

Expand Down

0 comments on commit ca2c698

Please sign in to comment.