From 8388362b672f00152d6c8026d0f19217ebc2eef7 Mon Sep 17 00:00:00 2001 From: Joseph Date: Mon, 27 Nov 2023 13:13:12 -0500 Subject: [PATCH] fixed tests --- testing/test_gflownet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/testing/test_gflownet.py b/testing/test_gflownet.py index 5e545f46..35642020 100644 --- a/testing/test_gflownet.py +++ b/testing/test_gflownet.py @@ -27,7 +27,7 @@ def test_trajectory_based_gflownet_generic(): ) pb_estimator = BoxPBEstimator(env=env, module=pb_module, n_components=1) - gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator) + gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, off_policy=False) mock_trajectories = Trajectories(env) result = gflownet.to_training_samples(mock_trajectories) @@ -79,7 +79,7 @@ def test_pytorch_inheritance(): ) pb_estimator = BoxPBEstimator(env=env, module=pb_module, n_components=1) - tbgflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator) + tbgflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, off_policy=False) assert hasattr( tbgflownet.parameters(), "__iter__" ), "Expected gflownet to have iterable parameters() method inherited from nn.Module"