Skip to content

Commit

Permalink
added helper methods and type checking for logZ, including allowing t…
Browse files Browse the repository at this point in the history
…he user to have a conditional logZ
  • Loading branch information
josephdviviano committed Sep 20, 2024
1 parent 83f276a commit 54af465
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/gfn/gflownet/trajectory_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from gfn.containers import Trajectories
from gfn.env import Env
from gfn.gflownet.base import TrajectoryBasedGFlowNet
from gfn.modules import GFNModule
from gfn.modules import GFNModule, ScalarEstimator


class TBGFlowNet(TrajectoryBasedGFlowNet):
Expand All @@ -23,22 +23,25 @@ class TBGFlowNet(TrajectoryBasedGFlowNet):
the DAG, or a singleton thereof, if self.logit_PB is a fixed DiscretePBEstimator.
Attributes:
logZ: a LogZEstimator instance.
logZ: a ScalarEstimator (for conditional GFNs) instance, or float.
log_reward_clip_min: If finite, clips log rewards to this value.
"""

def __init__(
self,
pf: GFNModule,
pb: GFNModule,
init_logZ: float = 0.0,
logZ: float | ScalarEstimator = 0.0,
log_reward_clip_min: float = -float("inf"),
):
super().__init__(pf, pb)

self.logZ = nn.Parameter(
torch.tensor(init_logZ)
) # TODO: Optionally, this should be a nn.Module to support conditional GFNs.
if isinstance(logZ, float):
self.logZ = nn.Parameter(torch.tensor(logZ))
else:
assert isinstance(logZ, ScalarEstimator), "logZ must be either float or a ScalarEstimator"
self.logZ = logZ

self.log_reward_clip_min = log_reward_clip_min

def loss(
Expand Down

0 comments on commit 54af465

Please sign in to comment.