Skip to content

Commit

Permalink
fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Nov 27, 2023
1 parent 038b67b commit 8388362
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions testing/test_gflownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_trajectory_based_gflownet_generic():
)
pb_estimator = BoxPBEstimator(env=env, module=pb_module, n_components=1)

gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator)
gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, off_policy=False)
mock_trajectories = Trajectories(env)

result = gflownet.to_training_samples(mock_trajectories)
Expand Down Expand Up @@ -79,7 +79,7 @@ def test_pytorch_inheritance():
)
pb_estimator = BoxPBEstimator(env=env, module=pb_module, n_components=1)

tbgflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator)
tbgflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, off_policy=False)
assert hasattr(
tbgflownet.parameters(), "__iter__"
), "Expected gflownet to have iterable parameters() method inherited from nn.Module"
Expand Down

0 comments on commit 8388362

Please sign in to comment.