diff --git a/tests/unit_tests/tools/test_base.py b/tests/unit_tests/tools/test_base.py index e79e3bb..c766bdc 100644 --- a/tests/unit_tests/tools/test_base.py +++ b/tests/unit_tests/tools/test_base.py @@ -2,38 +2,47 @@ import numpy as np -from choice_learn.toolbox.assortment_optimizer import MNLAssortmentOptimizer, LatentClassAssortmentOptimizer, LatentClassPricingOptimizer +from choice_learn.toolbox.assortment_optimizer import ( + LatentClassAssortmentOptimizer, + LatentClassPricingOptimizer, + MNLAssortmentOptimizer, +) solvers = ["or-tools"] + def test_mnl_assort_instantiate(): """Test instantiation with both solvers.""" for solv in solvers: MNLAssortmentOptimizer( - solver=solv, - utilities=np.array([1., 2., 3.]), - itemwise_values=np.array([0.5, 0.5, 0.5]), - assortment_size=2) + solver=solv, + utilities=np.array([1.0, 2.0, 3.0]), + itemwise_values=np.array([0.5, 0.5, 0.5]), + assortment_size=2, + ) def test_lc_assort_instantiate(): """Test instantiation with both solvers.""" for solv in solvers: LatentClassAssortmentOptimizer( - solver=solv, - class_weights=np.array([.2, .8]), - class_utilities=np.array([[1., 2., 3.], [3., 2., 1.]]), - itemwise_values=np.array([0.5, 0.5, 0.5]), - assortment_size=2) + solver=solv, + class_weights=np.array([0.2, 0.8]), + class_utilities=np.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]), + itemwise_values=np.array([0.5, 0.5, 0.5]), + assortment_size=2, + ) def test_lc_pricing_instantiate(): """Test instantiation with both solvers.""" for solv in solvers: LatentClassPricingOptimizer( - solver=solv, - class_weights=np.array([.2, .8]), - class_utilities=np.array([[[1., 1.1], [2., 2.1], [3., 3.1]], - [[3., 3.1], [2., 2.1], [1., 1.1]]]), - itemwise_values=np.array([[0.5, 1.2], [0.5, 1.2], [0.5, 1.2]]), - assortment_size=2) \ No newline at end of file + solver=solv, + class_weights=np.array([0.2, 0.8]), + class_utilities=np.array( + [[[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]], [[3.0, 3.1], [2.0, 2.1], [1.0, 1.1]]] + ), + itemwise_values=np.array([[0.5, 1.2], [0.5, 1.2], [0.5, 1.2]]), + assortment_size=2, + )