change recalculate_all to recalculate_all_logprobs
saleml committed Apr 2, 2024
1 parent 368af4c commit bff764d
Showing 4 changed files with 18 additions and 18 deletions.
12 changes: 6 additions & 6 deletions src/gfn/gflownet/base.py
@@ -122,7 +122,7 @@ def get_pfs_and_pbs(
         self,
         trajectories: Trajectories,
         fill_value: float = 0.0,
-        recalculate_all: bool = False,
+        recalculate_all_logprobs: bool = False,
     ) -> Tuple[
         TT["max_length", "n_trajectories", torch.float],
         TT["max_length", "n_trajectories", torch.float],
@@ -132,7 +132,7 @@ def get_pfs_and_pbs(
         More specifically it evaluates $\log P_F (s' \mid s)$ and $\log P_B(s \mid s')$
         for each transition in each trajectory in the batch.
-        Unless recalculate_all=True, in which case we re-evaluate the logprobs of the trajectories with
+        Unless recalculate_all_logprobs=True, in which case we re-evaluate the logprobs of the trajectories with
         the current self.pf. The following applies:
             - If trajectories have log_probs attribute, use them - this is usually for on-policy learning
             - Else, if trajectories have estimator_outputs attribute, transform them
@@ -165,10 +165,10 @@ def get_pfs_and_pbs(
         if valid_states.batch_shape != tuple(valid_actions.batch_shape):
             raise AssertionError("Something wrong happening with log_pf evaluations")
 
-        if has_log_probs(trajectories) and not recalculate_all:
+        if has_log_probs(trajectories) and not recalculate_all_logprobs:
             log_pf_trajectories = trajectories.log_probs
         else:
-            if trajectories.estimator_outputs is not None and not recalculate_all:
+            if trajectories.estimator_outputs is not None and not recalculate_all_logprobs:
                 estimator_outputs = trajectories.estimator_outputs[
                     ~trajectories.actions.is_dummy
                 ]
@@ -214,15 +214,15 @@ def get_pfs_and_pbs(
     def get_trajectories_scores(
         self,
         trajectories: Trajectories,
-        recalculate_all: bool = False,
+        recalculate_all_logprobs: bool = False,
     ) -> Tuple[
         TT["n_trajectories", torch.float],
         TT["n_trajectories", torch.float],
         TT["n_trajectories", torch.float],
     ]:
         """Given a batch of trajectories, calculate forward & backward policy scores."""
         log_pf_trajectories, log_pb_trajectories = self.get_pfs_and_pbs(
-            trajectories, recalculate_all=recalculate_all
+            trajectories, recalculate_all_logprobs=recalculate_all_logprobs
         )
 
         assert log_pf_trajectories is not None
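The base-class diff above is the core of the rename: both get_pfs_and_pbs and get_trajectories_scores now take recalculate_all_logprobs. A minimal usage sketch, not part of this commit; gflownet and trajectories are placeholder names, assumed to be a TBGFlowNet and a sampled Trajectories batch built as in the test file further down:

    # On-policy: trajectories already carry log_probs recorded at sampling time,
    # so the default recalculate_all_logprobs=False reuses them.
    log_pfs_on, log_pbs_on = gflownet.get_pfs_and_pbs(trajectories)

    # Off-policy or replayed trajectories: force re-evaluation under the current self.pf.
    log_pfs_off, log_pbs_off = gflownet.get_pfs_and_pbs(
        trajectories, recalculate_all_logprobs=True
    )

    # The same keyword is forwarded by get_trajectories_scores.
    log_pf, log_pb, scores = gflownet.get_trajectories_scores(
        trajectories, recalculate_all_logprobs=True
    )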
12 changes: 6 additions & 6 deletions src/gfn/gflownet/detailed_balance.py
@@ -42,7 +42,7 @@ def __init__(
         self.log_reward_clip_min = log_reward_clip_min
 
     def get_scores(
-        self, env: Env, transitions: Transitions, recalculate_all: bool = False
+        self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = False
     ) -> Tuple[
         TT["n_transitions", float],
         TT["n_transitions", float],
@@ -53,7 +53,7 @@ def get_scores(
         Args:
             transitions: a batch of transitions.
-        Unless recalculate_all=True, in which case we re-evaluate the logprobs of the transitions with
+        Unless recalculate_all_logprobs=True, in which case we re-evaluate the logprobs of the transitions with
         the current self.pf. The following applies:
             - If transitions have log_probs attribute, use them - this is usually for on-policy learning
             - Else, re-evaluate the log_probs using the current self.pf - this is usually for
@@ -74,7 +74,7 @@ def get_scores(
         if states.batch_shape != tuple(actions.batch_shape):
             raise ValueError("Something wrong happening with log_pf evaluations")
 
-        if has_log_probs(transitions) and not recalculate_all:
+        if has_log_probs(transitions) and not recalculate_all_logprobs:
             valid_log_pf_actions = transitions.log_probs
         else:
             # Evaluate the log PF of the actions
@@ -156,11 +156,11 @@ class ModifiedDBGFlowNet(PFBasedGFlowNet[Transitions]):
     """
 
     def get_scores(
-        self, transitions: Transitions, recalculate_all: bool = False
+        self, transitions: Transitions, recalculate_all_logprobs: bool = False
     ) -> TT["n_trajectories", torch.float]:
         """DAG-GFN-style detailed balance, when all states are connected to the sink.
-        Unless recalculate_all=True, in which case we re-evaluate the logprobs of the transitions with
+        Unless recalculate_all_logprobs=True, in which case we re-evaluate the logprobs of the transitions with
         the current self.pf. The following applies:
             - If transitions have log_probs attribute, use them - this is usually for on-policy learning
             - Else, re-evaluate the log_probs using the current self.pf - this is usually for
@@ -181,7 +181,7 @@ def get_scores(
         module_output = self.pf(states)
         pf_dist = self.pf.to_probability_distribution(states, module_output)
 
-        if has_log_probs(transitions) and not recalculate_all:
+        if has_log_probs(transitions) and not recalculate_all_logprobs:
             valid_log_pf_actions = transitions[mask].log_probs
         else:
             # Evaluate the log PF of the actions sampled off policy.
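Both detailed-balance variants expose the same keyword. A hedged sketch of the calls after this change; db_gflownet, modified_db_gflownet, env, and transitions are placeholder names that are not defined in this diff:

    # DBGFlowNet: scores over a batch of transitions, recomputing log-probs with the current pf.
    scores = db_gflownet.get_scores(env, transitions, recalculate_all_logprobs=True)

    # ModifiedDBGFlowNet: same flag, but env is not part of the signature.
    scores_mod = modified_db_gflownet.get_scores(transitions, recalculate_all_logprobs=True)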
8 changes: 4 additions & 4 deletions src/gfn/gflownet/trajectory_balance.py
@@ -42,7 +42,7 @@ def __init__(
         self.log_reward_clip_min = log_reward_clip_min
 
     def loss(
-        self, env: Env, trajectories: Trajectories, recalculate_all: bool = False
+        self, env: Env, trajectories: Trajectories, recalculate_all_logprobs: bool = False
     ) -> TT[0, float]:
         """Trajectory balance loss.
@@ -54,7 +54,7 @@ def loss(
         """
         del env  # unused
         _, _, scores = self.get_trajectories_scores(
-            trajectories, recalculate_all=recalculate_all
+            trajectories, recalculate_all_logprobs=recalculate_all_logprobs
         )
         loss = (scores + self.logZ).pow(2).mean()
         if torch.isnan(loss):
@@ -83,7 +83,7 @@ def __init__(
         self.log_reward_clip_min = log_reward_clip_min
 
     def loss(
-        self, env: Env, trajectories: Trajectories, recalculate_all: bool = False
+        self, env: Env, trajectories: Trajectories, recalculate_all_logprobs: bool = False
     ) -> TT[0, float]:
         """Log Partition Variance loss.
@@ -92,7 +92,7 @@ def loss(
         """
         del env  # unused
         _, _, scores = self.get_trajectories_scores(
-            trajectories, recalculate_all=recalculate_all
+            trajectories, recalculate_all_logprobs=recalculate_all_logprobs
         )
         loss = (scores - scores.mean()).pow(2).mean()
         if torch.isnan(loss):
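For the two trajectory-level losses the keyword is simply threaded through to get_trajectories_scores. A sketch under the assumption that tb_gflownet is a TBGFlowNet (or the log-partition-variance variant) and that env, trajectories, and optimizer are set up elsewhere:

    # Off-policy training step: recompute log-probs of replayed trajectories before the loss.
    loss = tb_gflownet.loss(env, trajectories, recalculate_all_logprobs=True)
    loss.backward()
    optimizer.step()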
4 changes: 2 additions & 2 deletions testing/test_parametrizations_and_losses.py
@@ -76,7 +76,7 @@ def test_get_pfs_and_pbs(env_name: str, preprocessor_name: str):
 
     log_pfs_on, log_pbs_on = gflownet_on.get_pfs_and_pbs(trajectories)
     log_pfs_off, log_pbs_off = gflownet_off.get_pfs_and_pbs(
-        trajectories, recalculate_all=True
+        trajectories, recalculate_all_logprobs=True
     )
@@ -92,7 +92,7 @@ def test_get_scores(env_name: str, preprocessor_name: str):
     gflownet_off = TBGFlowNet(pf=pf_estimator, pb=pb_estimator)
     scores_on = gflownet_on.get_trajectories_scores(trajectories)
     scores_off = gflownet_off.get_trajectories_scores(
-        trajectories, recalculate_all=True
+        trajectories, recalculate_all_logprobs=True
     )
     assert all(
         [
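Since the parameter is renamed rather than aliased, downstream code that passed the old keyword has to be updated the same way the tests are here; with the new signatures the old name fails at call time. A hedged example (gflownet, env, and trajectories are placeholders):

    # gflownet.loss(env, trajectories, recalculate_all=True)  # TypeError: unexpected keyword argument
    loss = gflownet.loss(env, trajectories, recalculate_all_logprobs=True)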
