diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index e38bb10a..b5aa2929 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -153,8 +153,8 @@ def get_pfs_and_pbs( if self.off_policy: # We re-use the values calculated in .sample_trajectories(). if trajectories.estimator_outputs is not None: - estimator_outputs = trajectories.estimator_outputs[ - ~trajectories.actions.is_dummy + estimator_outputs = trajectories.estimator_outputs[ # TODO: This contains `inf` when we use the new `stack_states` method in `samplers.py`! + ~trajectories.actions.is_dummy # And this causes later failures (p_f is not finite). ] else: raise Exception(