diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index fad0e8b5..4e07ab77 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -152,6 +152,8 @@ def get_pfs_and_pbs( if not isinstance(estimator_outputs, type(None)): idx = torch.ones(trajectories.actions.batch_shape).bool() estimator_outputs = estimator_outputs[idx] + else: + estimator_outputs = self.pf(valid_states) # TODO: Remove This is left here to compare the recomputed values with the # carried forward values -- which strangely don't always seem to