diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py
index 8bb15fd9..2711e09e 100644
--- a/src/gfn/gflownet/base.py
+++ b/src/gfn/gflownet/base.py
@@ -122,7 +122,7 @@ def get_pfs_and_pbs(
         self,
         trajectories: Trajectories,
         fill_value: float = 0.0,
-        recalculate_all: bool = False,
+        recalculate_all_logprobs: bool = False,
     ) -> Tuple[
         TT["max_length", "n_trajectories", torch.float],
         TT["max_length", "n_trajectories", torch.float],
@@ -132,7 +132,7 @@ def get_pfs_and_pbs(
         More specifically it evaluates $\log P_F (s' \mid s)$ and $\log P_B(s \mid s')$
         for each transition in each trajectory in the batch.
 
-        Unless recalculate_all=True, in which case we re-evaluate the logprobs of the trajectories with
+        Unless recalculate_all_logprobs=True, in which case we re-evaluate the logprobs of the trajectories with
         the current self.pf. The following applies:
             - If trajectories have log_probs attribute, use them - this is usually for on-policy learning
             - Else, if trajectories have estimator_outputs attribute, transform them
@@ -165,10 +165,10 @@ def get_pfs_and_pbs(
         if valid_states.batch_shape != tuple(valid_actions.batch_shape):
             raise AssertionError("Something wrong happening with log_pf evaluations")
 
-        if has_log_probs(trajectories) and not recalculate_all:
+        if has_log_probs(trajectories) and not recalculate_all_logprobs:
             log_pf_trajectories = trajectories.log_probs
         else:
-            if trajectories.estimator_outputs is not None and not recalculate_all:
+            if trajectories.estimator_outputs is not None and not recalculate_all_logprobs:
                 estimator_outputs = trajectories.estimator_outputs[
                     ~trajectories.actions.is_dummy
                 ]
@@ -214,7 +214,7 @@
     def get_trajectories_scores(
         self,
         trajectories: Trajectories,
-        recalculate_all: bool = False,
+        recalculate_all_logprobs: bool = False,
     ) -> Tuple[
         TT["n_trajectories", torch.float],
         TT["n_trajectories", torch.float],
@@ -222,7 +222,7 @@ def get_trajectories_scores(
     ]:
         """Given a batch of trajectories, calculate forward & backward policy scores."""
         log_pf_trajectories, log_pb_trajectories = self.get_pfs_and_pbs(
-            trajectories, recalculate_all=recalculate_all
+            trajectories, recalculate_all_logprobs=recalculate_all_logprobs
         )
 
         assert log_pf_trajectories is not None
diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py
index 28cdd175..3d97b1ad 100644
--- a/src/gfn/gflownet/detailed_balance.py
+++ b/src/gfn/gflownet/detailed_balance.py
@@ -42,7 +42,7 @@ def __init__(
         self.log_reward_clip_min = log_reward_clip_min
 
     def get_scores(
-        self, env: Env, transitions: Transitions, recalculate_all: bool = False
+        self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = False
     ) -> Tuple[
         TT["n_transitions", float],
         TT["n_transitions", float],
@@ -53,7 +53,7 @@ def get_scores(
         Args:
             transitions: a batch of transitions.
 
-        Unless recalculate_all=True, in which case we re-evaluate the logprobs of the transitions with
+        Unless recalculate_all_logprobs=True, in which case we re-evaluate the logprobs of the transitions with
         the current self.pf. The following applies:
             - If transitions have log_probs attribute, use them - this is usually for on-policy learning
             - Else, re-evaluate the log_probs using the current self.pf - this is usually for
@@ -74,7 +74,7 @@ def get_scores(
         if states.batch_shape != tuple(actions.batch_shape):
             raise ValueError("Something wrong happening with log_pf evaluations")
 
-        if has_log_probs(transitions) and not recalculate_all:
+        if has_log_probs(transitions) and not recalculate_all_logprobs:
             valid_log_pf_actions = transitions.log_probs
         else:
             # Evaluate the log PF of the actions
@@ -156,11 +156,11 @@ class ModifiedDBGFlowNet(PFBasedGFlowNet[Transitions]):
     """
 
     def get_scores(
-        self, transitions: Transitions, recalculate_all: bool = False
+        self, transitions: Transitions, recalculate_all_logprobs: bool = False
     ) -> TT["n_trajectories", torch.float]:
         """DAG-GFN-style detailed balance, when all states are connected to the sink.
 
-        Unless recalculate_all=True, in which case we re-evaluate the logprobs of the transitions with
+        Unless recalculate_all_logprobs=True, in which case we re-evaluate the logprobs of the transitions with
         the current self.pf. The following applies:
             - If transitions have log_probs attribute, use them - this is usually for on-policy learning
             - Else, re-evaluate the log_probs using the current self.pf - this is usually for
@@ -181,7 +181,7 @@ def get_scores(
         module_output = self.pf(states)
         pf_dist = self.pf.to_probability_distribution(states, module_output)
 
-        if has_log_probs(transitions) and not recalculate_all:
+        if has_log_probs(transitions) and not recalculate_all_logprobs:
             valid_log_pf_actions = transitions[mask].log_probs
         else:
             # Evaluate the log PF of the actions sampled off policy.
diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py
index 45db346c..b4abf3a5 100644
--- a/src/gfn/gflownet/trajectory_balance.py
+++ b/src/gfn/gflownet/trajectory_balance.py
@@ -42,7 +42,7 @@ def __init__(
         self.log_reward_clip_min = log_reward_clip_min
 
     def loss(
-        self, env: Env, trajectories: Trajectories, recalculate_all: bool = False
+        self, env: Env, trajectories: Trajectories, recalculate_all_logprobs: bool = False
     ) -> TT[0, float]:
         """Trajectory balance loss.
 
@@ -54,7 +54,7 @@ def loss(
         """
         del env  # unused
         _, _, scores = self.get_trajectories_scores(
-            trajectories, recalculate_all=recalculate_all
+            trajectories, recalculate_all_logprobs=recalculate_all_logprobs
         )
         loss = (scores + self.logZ).pow(2).mean()
         if torch.isnan(loss):
@@ -83,7 +83,7 @@ def __init__(
         self.log_reward_clip_min = log_reward_clip_min
 
     def loss(
-        self, env: Env, trajectories: Trajectories, recalculate_all: bool = False
+        self, env: Env, trajectories: Trajectories, recalculate_all_logprobs: bool = False
     ) -> TT[0, float]:
         """Log Partition Variance loss.
 
@@ -92,7 +92,7 @@ def loss(
         """
         del env  # unused
         _, _, scores = self.get_trajectories_scores(
-            trajectories, recalculate_all=recalculate_all
+            trajectories, recalculate_all_logprobs=recalculate_all_logprobs
        )
         loss = (scores - scores.mean()).pow(2).mean()
         if torch.isnan(loss):
diff --git a/testing/test_parametrizations_and_losses.py b/testing/test_parametrizations_and_losses.py
index 5904ddd4..95b69bc6 100644
--- a/testing/test_parametrizations_and_losses.py
+++ b/testing/test_parametrizations_and_losses.py
@@ -76,7 +76,7 @@ def test_get_pfs_and_pbs(env_name: str, preprocessor_name: str):
     log_pfs_on, log_pbs_on = gflownet_on.get_pfs_and_pbs(trajectories)
     log_pfs_off, log_pbs_off = gflownet_off.get_pfs_and_pbs(
-        trajectories, recalculate_all=True
+        trajectories, recalculate_all_logprobs=True
     )
 
@@ -92,7 +92,7 @@ def test_get_scores(env_name: str, preprocessor_name: str):
     gflownet_off = TBGFlowNet(pf=pf_estimator, pb=pb_estimator)
     scores_on = gflownet_on.get_trajectories_scores(trajectories)
     scores_off = gflownet_off.get_trajectories_scores(
-        trajectories, recalculate_all=True
+        trajectories, recalculate_all_logprobs=True
     )
     assert all(
         [
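
Usage note (not part of the patch): a minimal sketch of how the renamed keyword is intended to be called after this change. It assumes an `env`, a sampled `trajectories` batch, and `pf_estimator` / `pb_estimator` modules constructed as in the existing tests; those names are placeholders, not new API.

from gfn.gflownet import TBGFlowNet

gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator)

# On-policy: the log-probs stored on the sampled trajectories are reused.
on_policy_loss = gflownet.loss(env, trajectories)

# Off-policy (e.g. replay buffer or tempered sampling): force re-evaluation
# of log P_F and log P_B under the current self.pf.
off_policy_loss = gflownet.loss(
    env, trajectories, recalculate_all_logprobs=True
)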