Notebook fix #187

Merged: 3 commits, Sep 22, 2024
7 changes: 4 additions & 3 deletions src/gfn/containers/replay_buffer.py
@@ -117,13 +117,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         p_norm_distance: p-norm distance value to pass to torch.cdist, for the
             determination of novel states.
     """
+
     def __init__(
         self,
         env: Env,
         objects_type: Literal["transitions", "trajectories", "states"],
         capacity: int = 1000,
-        cutoff_distance: float = 0.,
-        p_norm_distance: float = 1.,
+        cutoff_distance: float = 0.0,
+        p_norm_distance: float = 1.0,
     ):
         """Instantiates a prioritized replay buffer.
         Args:
@@ -137,7 +138,7 @@ def __init__(
                 norms are >= 0).
             p_norm_distance: p-norm distance value to pass to torch.cdist, for the
                 determination of novel states.
-        """
+        """
         super().__init__(env, objects_type, capacity)
         self.cutoff_distance = cutoff_distance
         self.p_norm_distance = p_norm_distance
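For orientation, a minimal usage sketch of the constructor whose defaults were reformatted above. The env variable is assumed to be any gfn Env instance, and the argument values are illustrative, not taken from this PR.

    # Hypothetical usage sketch; mirrors the signature shown in the diff above.
    buffer = PrioritizedReplayBuffer(
        env,                          # any gfn Env instance (assumed)
        objects_type="trajectories",
        capacity=1000,
        cutoff_distance=0.0,  # novelty threshold described in the docstring above
        p_norm_distance=1.0,  # p passed to torch.cdist, i.e. an L1 distance
    )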
12 changes: 10 additions & 2 deletions src/gfn/gflownet/detailed_balance.py
@@ -46,13 +46,21 @@ def logF_named_parameters(self):
         try:
             return {k: v for k, v in self.named_parameters() if "logF" in k}
         except KeyError as e:
-            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))
+            print(
+                "logF not found in self.named_parameters. Are the weights tied with PF? {}".format(
+                    e
+                )
+            )

     def logF_parameters(self):
         try:
             return [v for k, v in self.named_parameters() if "logF" in k]
         except KeyError as e:
-            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))
+            print(
+                "logF not found in self.named_parameters. Are the weights tied with PF? {}".format(
+                    e
+                )
+            )

     def get_scores(
         self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = False
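These helpers make it straightforward to optimize the logF estimator separately from the rest of the GFlowNet, a common pattern with detailed balance. A hedged sketch, where gflownet is assumed to be an instance of this class and the learning rates are illustrative:

    import torch

    # Hypothetical: give logF its own learning rate via optimizer parameter groups.
    logF_params = gflownet.logF_parameters()
    other_params = [v for k, v in gflownet.named_parameters() if "logF" not in k]
    optimizer = torch.optim.Adam(
        [
            {"params": other_params, "lr": 1e-3},
            {"params": logF_params, "lr": 1e-2},
        ]
    )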
4 changes: 3 additions & 1 deletion src/gfn/gflownet/flow_matching.py
@@ -30,7 +30,9 @@ class FMGFlowNet(GFlowNet[Tuple[DiscreteStates, DiscreteStates]]):
     def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0):
         super().__init__()

-        assert isinstance(logF, DiscretePolicyEstimator), "logF must be a Discrete Policy Estimator"
+        assert isinstance(
+            logF, DiscretePolicyEstimator
+        ), "logF must be a Discrete Policy Estimator"
         self.logF = logF
         self.alpha = alpha

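The reformatted assertion enforces the same contract as before: logF must be a DiscretePolicyEstimator, not a bare module. A small sketch of the failure mode it guards against (the Linear module here is a stand-in, not from this PR):

    import torch.nn as nn

    try:
        FMGFlowNet(logF=nn.Linear(4, 4))  # a raw nn.Module, not an estimator
    except AssertionError as err:
        print(err)  # -> "logF must be a Discrete Policy Estimator"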
12 changes: 10 additions & 2 deletions src/gfn/gflownet/sub_trajectory_balance.py
@@ -81,13 +81,21 @@ def logF_named_parameters(self):
         try:
             return {k: v for k, v in self.named_parameters() if "logF" in k}
         except KeyError as e:
-            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))
+            print(
+                "logF not found in self.named_parameters. Are the weights tied with PF? {}".format(
+                    e
+                )
+            )

     def logF_parameters(self):
         try:
             return [v for k, v in self.named_parameters() if "logF" in k]
         except KeyError as e:
-            print("logF not found in self.named_parameters. Are the weights tied with PF? {}".format(e))
+            print(
+                "logF not found in self.named_parameters. Are the weights tied with PF? {}".format(
+                    e
+                )
+            )

     def cumulative_logprobs(
         self,
4 changes: 3 additions & 1 deletion src/gfn/gflownet/trajectory_balance.py
@@ -39,7 +39,9 @@ def __init__(
         if isinstance(logZ, float):
             self.logZ = nn.Parameter(torch.tensor(logZ))
         else:
-            assert isinstance(logZ, ScalarEstimator), "logZ must be either float or a ScalarEstimator"
+            assert isinstance(
+                logZ, ScalarEstimator
+            ), "logZ must be either float or a ScalarEstimator"
             self.logZ = logZ

         self.log_reward_clip_min = log_reward_clip_min
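For reference, the float branch above wraps logZ in nn.Parameter, so the scalar estimate of the log-partition function is trained jointly with the policy networks. A minimal standalone illustration of what that wrapping does:

    import torch
    import torch.nn as nn

    logZ = nn.Parameter(torch.tensor(0.0))  # the float path from the diff
    print(logZ.requires_grad)  # True: the optimizer will update logZ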
2,243 changes: 1,388 additions & 855 deletions tutorials/notebooks/intro_gfn_continuous_line_simple.ipynb

Large diffs are not rendered by default.

790 changes: 434 additions & 356 deletions tutorials/notebooks/intro_gfn_smiley.ipynb

Large diffs are not rendered by default.
