diff --git a/tests/integration_tests/models/test_simple_mnl.py b/tests/integration_tests/models/test_simple_mnl.py index c84d901d..c94dbcff 100644 --- a/tests/integration_tests/models/test_simple_mnl.py +++ b/tests/integration_tests/models/test_simple_mnl.py @@ -1,5 +1,7 @@ """Tests SimpleMNL.""" +import tensorflow as tf + from choice_learn.datasets import load_swissmetro from choice_learn.models import SimpleMNL @@ -18,10 +20,11 @@ def test_simple_mnl_lbfgs_fit_with_lbfgs(): def test_simple_mnl_lbfgs_fit_with_adam(): """Tests that SimpleMNL can fit with Adam.""" + tf.config.run_functions_eagerly(True) global dataset model = SimpleMNL(epochs=20, optimizer="adam", batch_size=256) - model.fit(dataset) + model.fit(dataset, get_report=True) model.evaluate(dataset) assert model.evaluate(dataset) < 1.0