diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py
index 321306e0..89722ab9 100644
--- a/src/gfn/gflownet/base.py
+++ b/src/gfn/gflownet/base.py
@@ -11,11 +11,7 @@
 from gfn.modules import GFNModule
 from gfn.samplers import Sampler
 from gfn.states import States
-from gfn.utils.common import has_log_probs
-from gfn.utils.handlers import (
-    has_conditioning_exception_handler,
-    no_conditioning_exception_handler,
-)
+from gfn.utils.prob_calculations import get_traj_pfs_and_pbs
 
 TrainingSampleType = TypeVar(
     "TrainingSampleType", bound=Union[Container, tuple[States, ...]]
@@ -145,6 +141,7 @@ def get_pfs_and_pbs(
             trajectories: Trajectories to evaluate.
             fill_value: Value to use for invalid states (i.e. $s_f$ that is added to
                 shorter trajectories).
+            recalculate_all_logprobs: Whether to re-evaluate all logprobs.
 
         Returns: A tuple of float tensors of shape (max_length, n_trajectories) containing
             the log_pf and log_pb for each action in each trajectory. The first one can be None.
@@ -153,103 +150,9 @@ def get_pfs_and_pbs(
             ValueError: if the trajectories are backward.
             AssertionError: when actions and states dimensions mismatch.
         """
-        # fill value is the value used for invalid states (sink state usually)
-        if trajectories.is_backward:
-            raise ValueError("Backward trajectories are not supported")
-
-        valid_states = trajectories.states[~trajectories.states.is_sink_state]
-        valid_actions = trajectories.actions[~trajectories.actions.is_dummy]
-
-        # uncomment next line for debugging
-        # assert trajectories.states.is_sink_state[:-1].equal(trajectories.actions.is_dummy)
-
-        if valid_states.batch_shape != tuple(valid_actions.batch_shape):
-            raise AssertionError("Something wrong happening with log_pf evaluations")
-
-        if has_log_probs(trajectories) and not recalculate_all_logprobs:
-            log_pf_trajectories = trajectories.log_probs
-        else:
-            if (
-                trajectories.estimator_outputs is not None
-                and not recalculate_all_logprobs
-            ):
-                estimator_outputs = trajectories.estimator_outputs[
-                    ~trajectories.actions.is_dummy
-                ]
-            else:
-                if trajectories.conditioning is not None:
-                    cond_dim = (-1,) * len(trajectories.conditioning.shape)
-                    traj_len = trajectories.states.tensor.shape[0]
-                    masked_cond = trajectories.conditioning.unsqueeze(0).expand(
-                        (traj_len,) + cond_dim
-                    )[~trajectories.states.is_sink_state]
-
-                    # Here, we pass all valid states, i.e., non-sink states.
-                    with has_conditioning_exception_handler("pf", self.pf):
-                        estimator_outputs = self.pf(valid_states, masked_cond)
-                else:
-                    # Here, we pass all valid states, i.e., non-sink states.
-                    with no_conditioning_exception_handler("pf", self.pf):
-                        estimator_outputs = self.pf(valid_states)
-
-            # Calculates the log PF of the actions sampled off policy.
-            valid_log_pf_actions = self.pf.to_probability_distribution(
-                valid_states, estimator_outputs
-            ).log_prob(
-                valid_actions.tensor
-            )  # Using the actions sampled off-policy.
-            log_pf_trajectories = torch.full_like(
-                trajectories.actions.tensor[..., 0],
-                fill_value=fill_value,
-                dtype=torch.float,
-            )
-            log_pf_trajectories[~trajectories.actions.is_dummy] = valid_log_pf_actions
-
-        non_initial_valid_states = valid_states[~valid_states.is_initial_state]
-        non_exit_valid_actions = valid_actions[~valid_actions.is_exit]
-
-        # Using all non-initial states, calculate the backward policy, and the logprobs
-        # of those actions.
-        if trajectories.conditioning is not None:
-            # We need to index the conditioning vector to broadcast over the states.
-            cond_dim = (-1,) * len(trajectories.conditioning.shape)
-            traj_len = trajectories.states.tensor.shape[0]
-            masked_cond = trajectories.conditioning.unsqueeze(0).expand(
-                (traj_len,) + cond_dim
-            )[~trajectories.states.is_sink_state][~valid_states.is_initial_state]
-
-            # Pass all valid states, i.e., non-sink states, except the initial state.
-            with has_conditioning_exception_handler("pb", self.pb):
-                estimator_outputs = self.pb(non_initial_valid_states, masked_cond)
-        else:
-            # Pass all valid states, i.e., non-sink states, except the initial state.
-            with no_conditioning_exception_handler("pb", self.pb):
-                estimator_outputs = self.pb(non_initial_valid_states)
-
-        valid_log_pb_actions = self.pb.to_probability_distribution(
-            non_initial_valid_states, estimator_outputs
-        ).log_prob(non_exit_valid_actions.tensor)
-
-        log_pb_trajectories = torch.full_like(
-            trajectories.actions.tensor[..., 0],
-            fill_value=fill_value,
-            dtype=torch.float,
+        return get_traj_pfs_and_pbs(
+            self.pf, self.pb, trajectories, fill_value, recalculate_all_logprobs
         )
-        log_pb_trajectories_slice = torch.full_like(
-            valid_actions.tensor[..., 0], fill_value=fill_value, dtype=torch.float
-        )
-        log_pb_trajectories_slice[~valid_actions.is_exit] = valid_log_pb_actions
-        log_pb_trajectories[~trajectories.actions.is_dummy] = log_pb_trajectories_slice
-
-        assert log_pf_trajectories.shape == (
-            trajectories.max_length,
-            trajectories.n_trajectories,
-        )
-        assert log_pb_trajectories.shape == (
-            trajectories.max_length,
-            trajectories.n_trajectories,
-        )
-        return log_pf_trajectories, log_pb_trajectories
 
     def get_trajectories_scores(
         self,
@@ -265,7 +168,6 @@ def get_trajectories_scores(
 
         Returns: A tuple of float tensors of shape (n_trajectories,) containing the total log_pf,
             total log_pb, and the total log-likelihood of the trajectories.
-
         """
         log_pf_trajectories, log_pb_trajectories = self.get_pfs_and_pbs(
             trajectories, recalculate_all_logprobs=recalculate_all_logprobs
diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py
index d5002220..8664c9de 100644
--- a/src/gfn/gflownet/detailed_balance.py
+++ b/src/gfn/gflownet/detailed_balance.py
@@ -12,6 +12,7 @@
     has_conditioning_exception_handler,
     no_conditioning_exception_handler,
 )
+from gfn.utils.prob_calculations import get_trans_pfs_and_pbs
 
 
 def check_compatibility(states, actions, transitions):
@@ -78,6 +79,13 @@ def logF_parameters(self):
             )
         )
 
+    def get_pfs_and_pbs(
+        self, transitions: Transitions, recalculate_all_logprobs: bool = False
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        return get_trans_pfs_and_pbs(
+            self.pf, self.pb, transitions, recalculate_all_logprobs
+        )
+
     def get_scores(
         self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = False
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -101,6 +109,7 @@ def get_scores(
         """
         if transitions.is_backward:
             raise ValueError("Backward transitions are not supported")
+
         states = transitions.states
         actions = transitions.actions
 
@@ -108,63 +117,31 @@
         # assert transitions.states.is_sink_state.equal(transitions.actions.is_dummy)
         check_compatibility(states, actions, transitions)
 
-        if has_log_probs(transitions) and not recalculate_all_logprobs:
-            valid_log_pf_actions = transitions.log_probs
-        else:
-            # Evaluate the log PF of the actions, with optional conditioning.
-            # TODO: Inefficient duplication in case of tempered policy
-            # The Transitions container should then have some
-            # estimator_outputs attribute as well, to avoid duplication here ?
-            # See (#156).
-            if transitions.conditioning is not None:
-                with has_conditioning_exception_handler("pf", self.pf):
-                    module_output = self.pf(states, transitions.conditioning)
-            else:
-                with no_conditioning_exception_handler("pf", self.pf):
-                    module_output = self.pf(states)
-
-            valid_log_pf_actions = self.pf.to_probability_distribution(
-                states, module_output
-            ).log_prob(actions.tensor)
+        log_pf_actions, log_pb_actions = self.get_pfs_and_pbs(
+            transitions, recalculate_all_logprobs
+        )
 
         # LogF is potentially a conditional computation.
         if transitions.conditioning is not None:
             with has_conditioning_exception_handler("logF", self.logF):
-                valid_log_F_s = self.logF(states, transitions.conditioning).squeeze(-1)
+                log_F_s = self.logF(states, transitions.conditioning).squeeze(-1)
         else:
             with no_conditioning_exception_handler("logF", self.logF):
-                valid_log_F_s = self.logF(states).squeeze(-1)
+                log_F_s = self.logF(states).squeeze(-1)
 
         if self.forward_looking:
             log_rewards = env.log_reward(states)  # TODO: RM unsqueeze(-1) ?
             if math.isfinite(self.log_reward_clip_min):
                 log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
-            valid_log_F_s = valid_log_F_s + log_rewards
+            log_F_s = log_F_s + log_rewards
 
-        preds = valid_log_pf_actions + valid_log_F_s
-        targets = torch.zeros_like(preds)
+        preds = log_pf_actions + log_F_s
 
         # uncomment next line for debugging
         # assert transitions.next_states.is_sink_state.equal(transitions.is_done)
 
         # automatically removes invalid transitions (i.e. s_f -> s_f)
         valid_next_states = transitions.next_states[~transitions.is_done]
-        non_exit_actions = actions[~actions.is_exit]
-
-        # Evaluate the log PB of the actions, with optional conditioning.
-        if transitions.conditioning is not None:
-            with has_conditioning_exception_handler("pb", self.pb):
-                module_output = self.pb(
-                    valid_next_states, transitions.conditioning[~transitions.is_done]
-                )
-        else:
-            with no_conditioning_exception_handler("pb", self.pb):
-                module_output = self.pb(valid_next_states)
-
-        valid_log_pb_actions = self.pb.to_probability_distribution(
-            valid_next_states, module_output
-        ).log_prob(non_exit_actions.tensor)
-
         valid_transitions_is_done = transitions.is_done[
             ~transitions.states.is_sink_state
         ]
@@ -179,23 +156,21 @@ def get_scores(
             with no_conditioning_exception_handler("logF", self.logF):
                 valid_log_F_s_next = self.logF(valid_next_states).squeeze(-1)
 
-        targets[~valid_transitions_is_done] = valid_log_pb_actions
-        log_pb_actions = targets.clone()
-        targets[~valid_transitions_is_done] += valid_log_F_s_next
+        log_F_s_next = torch.zeros_like(log_pb_actions)
+        log_F_s_next[~valid_transitions_is_done] += valid_log_F_s_next
 
         assert transitions.log_rewards is not None
         valid_transitions_log_rewards = transitions.log_rewards[
             ~transitions.states.is_sink_state
         ]
-        targets[valid_transitions_is_done] = valid_transitions_log_rewards[
+        log_F_s_next[valid_transitions_is_done] = valid_transitions_log_rewards[
            valid_transitions_is_done
         ]
+        targets = log_pb_actions + log_F_s_next
 
         scores = preds - targets
 
-        assert valid_log_pf_actions.shape == (transitions.n_transitions,)
-        assert log_pb_actions.shape == (transitions.n_transitions,)
         assert scores.shape == (transitions.n_transitions,)
 
-        return valid_log_pf_actions, log_pb_actions, scores
+        return log_pf_actions, log_pb_actions, scores
 
     def loss(self, env: Env, transitions: Transitions) -> torch.Tensor:
         """Detailed balance loss.
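Note on the refactor above: `get_scores` still forms the detailed-balance residual as `preds - targets`, with `preds = log_pf_actions + log_F_s` and `targets = log_pb_actions + log_F_s_next`, where `log_F_s_next` is replaced by the transition's log-reward on terminating transitions. A minimal, self-contained sketch of that arithmetic on made-up tensors (the values and the squared-error reduction at the end are illustrative, not taken from this diff):

import torch

# Four dummy transitions; all entries are illustrative only.
log_pf_actions = torch.tensor([-0.5, -1.2, -0.3, -0.7])  # log P_F(a | s)
log_pb_actions = torch.tensor([-0.6, 0.0, -0.4, 0.0])    # log P_B(s | s'); 0 on terminating transitions
log_F_s = torch.tensor([1.0, 0.5, 0.2, 0.8])             # log F(s)
is_done = torch.tensor([False, True, False, True])       # whether s' is terminal
valid_log_F_s_next = torch.tensor([0.9, 0.4])            # log F(s') for the two ongoing transitions
log_rewards = torch.tensor([0.1, 0.3])                   # log R(x) for the two terminating transitions

preds = log_pf_actions + log_F_s

# log F(s') where the trajectory continues, log R(x) where it terminates.
log_F_s_next = torch.zeros_like(log_pb_actions)
log_F_s_next[~is_done] = valid_log_F_s_next
log_F_s_next[is_done] = log_rewards

targets = log_pb_actions + log_F_s_next
scores = preds - targets   # per-transition detailed-balance residual
loss = (scores**2).mean()  # a typical DB loss reduces the squared residuals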
diff --git a/src/gfn/modules.py b/src/gfn/modules.py
index 4351b462..877e77f4 100644
--- a/src/gfn/modules.py
+++ b/src/gfn/modules.py
@@ -271,7 +271,7 @@ def _forward_trunk(
 
         return out
 
-    def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
+    def forward(self, states: States, conditioning: torch.Tensor) -> torch.Tensor:
         """Forward pass of the module.
 
         Args:
diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py
index eb224fbf..819620f0 100644
--- a/src/gfn/samplers.py
+++ b/src/gfn/samplers.py
@@ -35,7 +35,7 @@ def sample_actions(
         save_estimator_outputs: bool = False,
         save_logprobs: bool = True,
         **policy_kwargs: Any,
-    ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None,]:
+    ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None]:
         """Samples actions from the given states.
 
         Args:
diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py
index e0d6d3b0..f4948d0d 100644
--- a/src/gfn/utils/distributions.py
+++ b/src/gfn/utils/distributions.py
@@ -3,7 +3,7 @@
 
 
 class UnsqueezedCategorical(Categorical):
-    """Samples froma categorical distribution with an unsqueezed final dimension.
+    """Samples from a categorical distribution with an unsqueezed final dimension.
 
     Samples are unsqueezed to be of shape (batch_size, 1) instead of (batch_size,).
 
diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py
new file mode 100644
index 00000000..550e1540
--- /dev/null
+++ b/src/gfn/utils/prob_calculations.py
@@ -0,0 +1,232 @@
+from typing import Tuple
+
+import torch
+
+from gfn.containers import Trajectories, Transitions
+from gfn.modules import GFNModule
+from gfn.states import States
+from gfn.utils.common import has_log_probs
+from gfn.utils.handlers import (
+    has_conditioning_exception_handler,
+    no_conditioning_exception_handler,
+)
+
+
+def check_cond_forward(
+    module: GFNModule,
+    module_name: str,
+    states: States,
+    condition: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if condition is not None:
+        with has_conditioning_exception_handler(module_name, module):
+            return module(states, condition)
+    else:
+        with no_conditioning_exception_handler(module_name, module):
+            return module(states)
+
+
+#########################
+##### Trajectories #####
+#########################
+
+
+def get_traj_pfs_and_pbs(
+    pf: GFNModule,
+    pb: GFNModule,
+    trajectories: Trajectories,
+    fill_value: float = 0.0,
+    recalculate_all_logprobs: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # fill value is the value used for invalid states (sink state usually)
+    if trajectories.is_backward:
+        raise ValueError("Backward trajectories are not supported")
+
+    # uncomment next line for debugging
+    # assert trajectories.states.is_sink_state[:-1].equal(trajectories.actions.is_dummy)
+
+    log_pf_trajectories = get_traj_pfs(
+        pf,
+        trajectories,
+        fill_value=fill_value,
+        recalculate_all_logprobs=recalculate_all_logprobs,
+    )
+    log_pb_trajectories = get_traj_pbs(pb, trajectories, fill_value=fill_value)
+
+    assert log_pf_trajectories.shape == (
+        trajectories.max_length,
+        trajectories.n_trajectories,
+    )
+    assert log_pb_trajectories.shape == (
+        trajectories.max_length,
+        trajectories.n_trajectories,
+    )
+
+    return log_pf_trajectories, log_pb_trajectories
+
+
+def get_traj_pfs(
+    pf: GFNModule,
+    trajectories: Trajectories,
+    fill_value: float = 0.0,
+    recalculate_all_logprobs: bool = False,
+) -> torch.Tensor:
+    state_mask = ~trajectories.states.is_sink_state
+    action_mask = ~trajectories.actions.is_dummy
+
+    valid_states = trajectories.states[state_mask]
+    valid_actions = trajectories.actions[action_mask]
+
+    if valid_states.batch_shape != tuple(valid_actions.batch_shape):
+        raise AssertionError("Something wrong happening with log_pf evaluations")
+
+    if has_log_probs(trajectories) and not recalculate_all_logprobs:
+        log_pf_trajectories = trajectories.log_probs
+    else:
+        if trajectories.estimator_outputs is not None and not recalculate_all_logprobs:
+            estimator_outputs = trajectories.estimator_outputs[action_mask]
+        else:
+            masked_cond = None
+            if trajectories.conditioning is not None:
+                cond_dim = (-1,) * len(trajectories.conditioning.shape)
+                traj_len = trajectories.states.tensor.shape[0]
+                masked_cond = trajectories.conditioning.unsqueeze(0).expand(
+                    (traj_len,) + cond_dim
+                )[state_mask]
+
+            estimator_outputs = check_cond_forward(pf, "pf", valid_states, masked_cond)
+
+        # Calculates the log PF of the actions sampled off policy.
+        valid_log_pf_actions = pf.to_probability_distribution(
+            valid_states, estimator_outputs
+        ).log_prob(
+            valid_actions.tensor
+        )  # Using the actions sampled off-policy.
+
+        log_pf_trajectories = torch.full_like(
+            trajectories.actions.tensor[..., 0],
+            fill_value=fill_value,
+            dtype=torch.float,
+        )
+        log_pf_trajectories[action_mask] = valid_log_pf_actions
+
+    return log_pf_trajectories
+
+
+def get_traj_pbs(
+    pb: GFNModule, trajectories: Trajectories, fill_value: float = 0.0
+) -> torch.Tensor:
+    # Note the different mask for valid states and actions compared to the pf case.
+    state_mask = (
+        ~trajectories.states.is_sink_state & ~trajectories.states.is_initial_state
+    )
+    action_mask = ~trajectories.actions.is_dummy & ~trajectories.actions.is_exit
+
+    valid_states = trajectories.states[state_mask]
+    valid_actions = trajectories.actions[action_mask]
+
+    if valid_states.batch_shape != tuple(valid_actions.batch_shape):
+        raise AssertionError("Something wrong happening with log_pf evaluations")
+
+    # Using all non-initial states, calculate the backward policy, and the logprobs
+    # of those actions.
+    masked_cond = None
+    if trajectories.conditioning is not None:
+        # We need to index the conditioning vector to broadcast over the states.
+        cond_dim = (-1,) * len(trajectories.conditioning.shape)
+        traj_len = trajectories.states.tensor.shape[0]
+        masked_cond = trajectories.conditioning.unsqueeze(0).expand(
+            (traj_len,) + cond_dim
+        )[state_mask]
+
+    estimator_outputs = check_cond_forward(pb, "pb", valid_states, masked_cond)
+
+    valid_log_pb_actions = pb.to_probability_distribution(
+        valid_states, estimator_outputs
+    ).log_prob(valid_actions.tensor)
+
+    log_pb_trajectories = torch.full_like(
+        trajectories.actions.tensor[..., 0],
+        fill_value=fill_value,
+        dtype=torch.float,
+    )
+    log_pb_trajectories[action_mask] = valid_log_pb_actions
+
+    return log_pb_trajectories
+
+
+########################
+##### Transitions #####
+########################
+
+
+def get_trans_pfs_and_pbs(
+    pf: GFNModule,
+    pb: GFNModule,
+    transitions: Transitions,
+    recalculate_all_logprobs: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if transitions.is_backward:
+        raise ValueError("Backward transitions are not supported")
+
+    log_pf_transitions = get_trans_pfs(pf, transitions, recalculate_all_logprobs)
+    log_pb_transitions = get_trans_pbs(pb, transitions)
+
+    assert log_pf_transitions.shape == (transitions.n_transitions,)
+    assert log_pb_transitions.shape == (transitions.n_transitions,)
+
+    return log_pf_transitions, log_pb_transitions
+
+
+def get_trans_pfs(
+    pf: GFNModule, transitions: Transitions, recalculate_all_logprobs: bool = False
+) -> torch.Tensor:
+    states = transitions.states
+    actions = transitions.actions
+
+    if has_log_probs(transitions) and not recalculate_all_logprobs:
+        log_pf_actions = transitions.log_probs
+    else:
+        # Evaluate the log PF of the actions, with optional conditioning.
+        # TODO: Inefficient duplication in case of tempered policy
+        # The Transitions container should then have some
+        # estimator_outputs attribute as well, to avoid duplication here ?
+        # See (#156).
+        estimator_outputs = check_cond_forward(
+            pf, "pf", states, transitions.conditioning
+        )
+
+        log_pf_actions = pf.to_probability_distribution(
+            states, estimator_outputs
+        ).log_prob(actions.tensor)
+
+    return log_pf_actions
+
+
+def get_trans_pbs(pb: GFNModule, transitions: Transitions) -> torch.Tensor:
+    # automatically removes invalid transitions (i.e. s_f -> s_f)
+    valid_next_states = transitions.next_states[~transitions.is_done]
+    non_exit_actions = transitions.actions[~transitions.actions.is_exit]
+
+    # Evaluate the log PB of the actions, with optional conditioning.
+    masked_cond = (
+        transitions.conditioning[~transitions.is_done]
+        if transitions.conditioning is not None
+        else None
+    )
+    estimator_outputs = check_cond_forward(pb, "pb", valid_next_states, masked_cond)
+
+    valid_log_pb_actions = pb.to_probability_distribution(
+        valid_next_states, estimator_outputs
+    ).log_prob(non_exit_actions.tensor)
+
+    # Evaluate the log PB of the actions.
+    log_pb_actions = torch.zeros(
+        (transitions.n_transitions,),
+        dtype=torch.float,
+        device=valid_log_pb_actions.device,
+    )
+
+    log_pb_actions[~transitions.is_done] = valid_log_pb_actions
+
+    return log_pb_actions
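Usage note: the helpers introduced in gfn.utils.prob_calculations can also be called directly, outside the GFlowNet classes. A rough sketch, assuming `gflownet` is a PF/PB-based GFlowNet exposing `.pf` and `.pb`, and `trajectories` / `transitions` are previously sampled containers (these three objects are assumed to exist and are not defined in this diff):

from gfn.utils.prob_calculations import get_traj_pfs_and_pbs, get_trans_pfs_and_pbs

# Per-step log P_F and log P_B over whole trajectories.
# Shapes: (max_length, n_trajectories); steps padded with s_f hold `fill_value`.
log_pfs, log_pbs = get_traj_pfs_and_pbs(
    gflownet.pf,
    gflownet.pb,
    trajectories,
    fill_value=0.0,
    recalculate_all_logprobs=True,  # ignore any cached log_probs / estimator_outputs
)

# Per-transition log P_F and log P_B, each of shape (n_transitions,).
log_pf_t, log_pb_t = get_trans_pfs_and_pbs(
    gflownet.pf, gflownet.pb, transitions, recalculate_all_logprobs=True
)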