diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 5e04151d..e38bb10a 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -24,6 +24,7 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]): A formal definition of GFlowNets is given in Sec. 3 of [GFlowNet Foundations](https://arxiv.org/pdf/2111.09266). """ + log_reward_clip_min = float("-inf") # Default off. @abstractmethod def sample_trajectories( @@ -214,7 +215,7 @@ def get_trajectories_scores( total_log_pb_trajectories = log_pb_trajectories.sum(dim=0) log_rewards = trajectories.log_rewards - # TODO: log_reward_clip_min isn't defined in base (#155). + if math.isfinite(self.log_reward_clip_min) and log_rewards is not None: log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)