diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py
index 1b8da40a..f07835c3 100644
--- a/src/gfn/gflownet/sub_trajectory_balance.py
+++ b/src/gfn/gflownet/sub_trajectory_balance.py
@@ -253,31 +253,6 @@ def get_scores(
                 i,
             )
 
-            # preds = (
-            #     log_pf_trajectories_cum[i:]
-            #     - log_pf_trajectories_cum[:-i]
-            #     + current_log_state_flows
-            # )
-
-            # targets = torch.full_like(preds, fill_value=-float("inf"))
-            # assert trajectories.log_rewards is not None
-            # log_rewards = trajectories.log_rewards[trajectories.when_is_done >= i]
-
-            # targets.T[is_terminal_mask[i - 1 :].T] = log_rewards
-
-            # # For now, the targets contain the log-rewards of the ending sub trajectories
-            # # We need to add to that the log-probabilities of the backward actions up-to
-            # # the sub-trajectory's terminating state
-            # if i > 1:
-            #     targets[is_terminal_mask[i - 1 :]] += (
-            #         log_pb_trajectories_cum[i - 1 :] - log_pb_trajectories_cum[: -i + 1]
-            #     )[:-1][is_terminal_mask[i - 1 :]]
-
-            # # The following creates the targets for the non-finishing sub-trajectories
-            # targets[~full_mask[i - 1 :]] = (
-            #     log_pb_trajectories_cum[i:] - log_pb_trajectories_cum[:-i]
-            # )[:-1][~full_mask[i - 1 : -1]] + log_state_flows[i:][~sink_states_mask[i:]]
-
             flattening_mask = trajectories.when_is_done.lt(
                 torch.arange(
                     i,
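
For context on the deleted block: the commented-out code was an earlier vectorized construction of the sub-trajectory-balance preds and targets via cumulative sums of log-probabilities. Below is a minimal, self-contained sketch of that cumulative-sum trick, with toy shapes and hypothetical names (`log_pf`, `log_state_flows`) modeled on the deleted comments' variables; it is an illustration of the technique, not the library's API.

```python
import torch

T, B = 5, 3                          # toy max trajectory length and batch size
log_pf = torch.randn(T, B)           # per-step forward log-probs log P_F(a_t | s_t)
log_state_flows = torch.randn(T, B)  # log F(s_t) for each visited state

# Prepend a zero row so that log_pf_cum[t] is the sum of the first t step log-probs.
log_pf_cum = torch.cat([torch.zeros(1, B), log_pf.cumsum(dim=0)], dim=0)

i = 2  # sub-trajectory length
# For a sub-trajectory starting at step t:
#   preds[t] = log F(s_t) + sum_{k=t}^{t+i-1} log P_F(a_k | s_k)
# and the difference of cumulative sums computes the inner sum for every t at once.
preds = log_pf_cum[i:] - log_pf_cum[:-i] + log_state_flows[: T - i + 1]
print(preds.shape)  # (T - i + 1, B) = (4, 3)
```

The difference-of-cumsums form avoids recomputing overlapping per-step sums separately for each sub-trajectory length, which is why both the dead code and its replacement are organized around `log_pf_trajectories_cum`-style tensors.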