From e052c826b4e4946ad5664700540ff93dffa7c682 Mon Sep 17 00:00:00 2001 From: Joseph Date: Mon, 20 Nov 2023 22:53:18 -0500 Subject: [PATCH] added back in default recomputing behaviour for pf in off policy mode. --- src/gfn/gflownet/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index fad0e8b5..4e07ab77 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -152,6 +152,8 @@ def get_pfs_and_pbs( if not isinstance(estimator_outputs, type(None)): idx = torch.ones(trajectories.actions.batch_shape).bool() estimator_outputs = estimator_outputs[idx] + else: + estimator_outputs = self.pf(valid_states) # TODO: Remove This is left here to compare the recomputed values with the # carried forward values -- which strangely don't always seem to