Skip to content

Commit

Permalink
added default value for log_reward_clip_min in abstract base class
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Feb 16, 2024
1 parent ebfd4c8 commit 8258b19
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/gfn/gflownet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]):
A formal definition of GFlowNets is given in Sec. 3 of [GFlowNet Foundations](https://arxiv.org/pdf/2111.09266).
"""
log_reward_clip_min = float("-inf") # Default off.

@abstractmethod
def sample_trajectories(
Expand Down Expand Up @@ -214,7 +215,7 @@ def get_trajectories_scores(
total_log_pb_trajectories = log_pb_trajectories.sum(dim=0)

log_rewards = trajectories.log_rewards
# TODO: log_reward_clip_min isn't defined in base (#155).

if math.isfinite(self.log_reward_clip_min) and log_rewards is not None:
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)

Expand Down

0 comments on commit 8258b19

Please sign in to comment.