From ca2c69846b5093716a8dd7ac8407fe4680fe4308 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 30 Aug 2024 07:40:33 -0400 Subject: [PATCH] loss reduction options --- src/gfn/gflownet/base.py | 18 +++++++++++++++++- src/gfn/gflownet/detailed_balance.py | 10 ++++++---- src/gfn/gflownet/flow_matching.py | 18 +++++++++++------- src/gfn/gflownet/sub_trajectory_balance.py | 14 +++++++++----- src/gfn/gflownet/trajectory_balance.py | 10 +++++++--- 5 files changed, 50 insertions(+), 20 deletions(-) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 032639a2..b5a63727 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -20,6 +20,16 @@ ) +def loss_reduce(loss, method): + """Utility function to handle loss aggregation strategies.""" + if method == "mean": + return torch.mean(loss) + elif method == "sum": + return torch.sum(loss) + elif method == "none": + return loss + + class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]): """Abstract Base Class for GFlowNets. @@ -73,7 +83,7 @@ def to_training_samples(self, trajectories: Trajectories) -> TrainingSampleType: """Converts trajectories to training samples. The type depends on the GFlowNet.""" @abstractmethod - def loss(self, env: Env, training_objects): + def loss(self, env: Env, training_objects, reduction: str): """Computes the loss given the training objects.""" @@ -116,6 +126,12 @@ def pf_pb_named_parameters(self): def pf_pb_parameters(self): return [v for k, v in self.named_parameters() if "pb" in k or "pf" in k] + def logF_named_parameters(self): + return {k: v for k, v in self.named_parameters() if "logF" in k} + + def logF_parameters(self): + return [v for k, v in self.named_parameters() if "logF" in k] + class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]): def get_pfs_and_pbs( diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index 3d97b1ad..b7019d36 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -6,7 +6,7 @@ from gfn.containers import Trajectories, Transitions from gfn.env import Env -from gfn.gflownet.base import PFBasedGFlowNet +from gfn.gflownet.base import PFBasedGFlowNet, loss_reduce from gfn.modules import GFNModule, ScalarEstimator from gfn.utils.common import has_log_probs @@ -130,13 +130,14 @@ def get_scores( return (valid_log_pf_actions, log_pb_actions, scores) - def loss(self, env: Env, transitions: Transitions) -> TT[0, float]: + def loss(self, env: Env, transitions: Transitions, reduction: str = "mean") -> TT[0, float]: """Detailed balance loss. The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).""" _, _, scores = self.get_scores(env, transitions) - loss = torch.mean(scores**2) + scores = scores**2 + loss = loss_reduce(scores, reduction) if torch.isnan(loss): raise ValueError("loss is nan") @@ -215,7 +216,8 @@ def get_scores( def loss(self, env: Env, transitions: Transitions) -> TT[0, float]: """Calculates the modified detailed balance loss.""" scores = self.get_scores(transitions) - return torch.mean(scores**2) + scores = scores**2 + return loss_reduce(loss, reduction) def to_training_samples(self, trajectories: Trajectories) -> Transitions: return trajectories.to_transitions() diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 5764cb8e..4917c973 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -5,7 +5,7 @@ from gfn.containers import Trajectories from gfn.env import Env -from gfn.gflownet.base import GFlowNet +from gfn.gflownet.base import GFlowNet, loss_reduce from gfn.modules import DiscretePolicyEstimator from gfn.samplers import Sampler from gfn.states import DiscreteStates @@ -60,6 +60,7 @@ def flow_matching_loss( self, env: Env, states: DiscreteStates, + reduction: str = "mean", ) -> TT["n_trajectories", torch.float]: """Computes the FM for the provided states. @@ -113,11 +114,12 @@ def flow_matching_loss( log_incoming_flows = torch.logsumexp(incoming_log_flows, dim=-1) log_outgoing_flows = torch.logsumexp(outgoing_log_flows, dim=-1) + scores = (log_incoming_flows - log_outgoing_flows).pow(2) - return (log_incoming_flows - log_outgoing_flows).pow(2).mean() + return loss_reduce(scores, reduction) def reward_matching_loss( - self, env: Env, terminating_states: DiscreteStates + self, env: Env, terminating_states: DiscreteStates,reduction: str = "mean" ) -> TT[0, float]: """Calculates the reward matching loss from the terminating states.""" del env # Unused @@ -127,10 +129,12 @@ def reward_matching_loss( # Handle the boundary condition (for all x, F(X->S_f) = R(x)). terminating_log_edge_flows = log_edge_flows[:, -1] log_rewards = terminating_states.log_rewards - return (terminating_log_edge_flows - log_rewards).pow(2).mean() + scores = (terminating_log_edge_flows - log_rewards).pow(2) + + return loss_reduce(scores, reduction) def loss( - self, env: Env, states_tuple: Tuple[DiscreteStates, DiscreteStates] + self, env: Env, states_tuple: Tuple[DiscreteStates, DiscreteStates], reduction: str = "mean" ) -> TT[0, float]: """Given a batch of non-terminal and terminal states, compute a loss. @@ -139,8 +143,8 @@ def loss( (i.e. non-terminal states), and the second one being the terminal states of the trajectories.""" intermediary_states, terminating_states = states_tuple - fm_loss = self.flow_matching_loss(env, intermediary_states) - rm_loss = self.reward_matching_loss(env, terminating_states) + fm_loss = self.flow_matching_loss(env, intermediary_states, reduction=reduction) + rm_loss = self.reward_matching_loss(env, terminating_states, reduction=reduction) return fm_loss + self.alpha * rm_loss def to_training_samples( diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index 6e8b1324..1d2107cb 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -6,7 +6,7 @@ from gfn.containers import Trajectories from gfn.env import Env -from gfn.gflownet.base import TrajectoryBasedGFlowNet +from gfn.gflownet.base import TrajectoryBasedGFlowNet, loss_reduce from gfn.modules import GFNModule, ScalarEstimator ContributionsTensor = TT["max_len * (1 + max_len) / 2", "n_trajectories"] @@ -382,15 +382,17 @@ def get_geometric_within_contributions( return contributions - def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]: + def loss(self, env: Env, trajectories: Trajectories, reduction: str = "mean") -> TT[0, float]: # Get all scores and masks from the trajectories. scores, flattening_masks = self.get_scores(env, trajectories) flattening_mask = torch.cat(flattening_masks) all_scores = torch.cat(scores, 0) if self.weighting == "DB": - # Longer trajectories contribute more to the loss - return scores[0][~flattening_masks[0]].pow(2).mean() + # Longer trajectories contribute more to the loss. + # TODO: is this correct with `loss_reduce`? + final_scores = scores[0][~flattening_masks[0]].pow(2) + return loss_reduce(final_scores, reduction) elif self.weighting == "geometric": # The position i of the following 1D tensor represents the number of sub- @@ -440,4 +442,6 @@ def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]: flat_contributions.sum() - 1.0 ).abs() < 1e-5, f"{flat_contributions.sum()}" losses = flat_contributions * all_scores[~flattening_mask].pow(2) - return losses.sum() + + # TODO: default was sum, does this even work with mean? + return loss_reduce(final_scores, reduction) diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index 1f8799d9..9be42e58 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -9,7 +9,7 @@ from gfn.containers import Trajectories from gfn.env import Env -from gfn.gflownet.base import TrajectoryBasedGFlowNet +from gfn.gflownet.base import TrajectoryBasedGFlowNet, loss_reduce from gfn.modules import GFNModule @@ -46,6 +46,7 @@ def loss( env: Env, trajectories: Trajectories, recalculate_all_logprobs: bool = False, + reduction: str = "mean", ) -> TT[0, float]: """Trajectory balance loss. @@ -59,7 +60,8 @@ def loss( _, _, scores = self.get_trajectories_scores( trajectories, recalculate_all_logprobs=recalculate_all_logprobs ) - loss = (scores + self.logZ).pow(2).mean() + scores = (scores + self.logZ).pow(2) + loss = loss_reduce(scores, reduction) if torch.isnan(loss): raise ValueError("loss is nan") @@ -90,6 +92,7 @@ def loss( env: Env, trajectories: Trajectories, recalculate_all_logprobs: bool = False, + reduction: str = "mean", ) -> TT[0, float]: """Log Partition Variance loss. @@ -100,7 +103,8 @@ def loss( _, _, scores = self.get_trajectories_scores( trajectories, recalculate_all_logprobs=recalculate_all_logprobs ) - loss = (scores - scores.mean()).pow(2).mean() + scores = (scores - scores.mean()).pow(2) + loss = loss_reduce(scores, reduction) if torch.isnan(loss): raise ValueError("loss is NaN.")