From 2ec81cb091951edb3dbbc4d601f7e8fb5e5700dd Mon Sep 17 00:00:00 2001 From: Vincent Auriau Date: Wed, 25 Sep 2024 13:50:37 +0200 Subject: [PATCH] Test eager mode for some model tests (#161) --- .gitignore | 2 +- choice_learn/models/base_model.py | 24 +++++----- choice_learn/models/simple_mnl.py | 2 +- pyproject.toml | 2 +- tests/unit_tests/models/test_simplemnl.py | 53 ++++++++++++++++++++++- 5 files changed, 69 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index 77c3a6cd..d3ae9950 100644 --- a/.gitignore +++ b/.gitignore @@ -141,4 +141,4 @@ secrets/* choice_learn/datasets/data/expedia.csv choice_learn/datasets/cache/* !choice_learn/datasets/cache/.gitkeep -choice_learn/datasets/data/lmpc.dat +choice_learn/datasets/data/lpmc.dat diff --git a/choice_learn/models/base_model.py b/choice_learn/models/base_model.py index 16dc8617..4e891767 100644 --- a/choice_learn/models/base_model.py +++ b/choice_learn/models/base_model.py @@ -519,15 +519,18 @@ def save_model(self, path): path : str path to the folder where to save the model """ - if not os.exists(path): + if not os.path.exists(path): Path(path).mkdir(parents=True) for i, weight in enumerate(self.trainable_weights): - tf.keras.savedmodel.save(Path(path) / f"weight_{i}") + np.save(Path(path) / f"weight_{i}.npy", weight.numpy()) # To improve for non-string attributes - params = self.__dict__ - json.dump(Path(path) / "params.json", params) + params = {} + for k, v in self.__dict__.items(): + if isinstance(v, (int, float, str, dict)): + params[k] = v + json.dump(params, open(os.path.join(path, "params.json"), "w")) # Save optimizer state @@ -546,21 +549,22 @@ def load_model(cls, path): Loaded ChoiceModel """ obj = cls() - obj.trainable_weights = [] + obj._trainable_weights = [] + i = 0 - weight_path = f"weight_{i}" + weight_path = f"weight_{i}.npy" while weight_path in os.listdir(path): - obj.trainable_weights.append(tf.keras.load_model.load(Path(path) / weight_path)) + obj._trainable_weights.append(tf.Variable(np.load(Path(path) / weight_path))) i += 1 - weight_path = f"weight_{i}" + weight_path = f"weight_{i}.npy" # To improve for non string attributes - params = json.load(Path(path) / "params.json") + params = json.load(open(Path(path) / "params.json", "r")) for k, v in params.items(): setattr(obj, k, v) # Load optimizer step - return cls + return obj def predict_probas(self, choice_dataset, batch_size=-1): """Predicts the choice probabilities for each choice and each product of a ChoiceDataset. diff --git a/choice_learn/models/simple_mnl.py b/choice_learn/models/simple_mnl.py index e0dc1386..c075bc52 100644 --- a/choice_learn/models/simple_mnl.py +++ b/choice_learn/models/simple_mnl.py @@ -351,5 +351,5 @@ def clone(self): if hasattr(self, "_items_features_names"): clone._items_features_names = self._items_features_names if hasattr(self, "_shared_features_names"): - clone._contexts_features_names = self._contexts_features_names + clone._shared_features_names = self._shared_features_names return clone diff --git a/pyproject.toml b/pyproject.toml index c7f4d509..a6efc98a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ select = [ "PTH", "PD", ] # See: https://beta.ruff.rs/docs/rules/ -ignore = ["D203", "D213", "ANN101", "ANN102", "ANN204", "ANN001", "ANN002", "ANN202", "ANN201", "ANN206", "ANN003", "PTH100", "PTH118", "PTH123","PTH113", "PTH104"] +ignore = ["D203", "D213", "ANN101", "ANN102", "ANN204", "ANN001", "ANN002", "ANN202", "ANN201", "ANN206", "ANN003", "PTH100", "PTH110", "PTH118", "PTH123","PTH113", "PTH104"] line-length = 100 exclude = [ ".bzr", diff --git a/tests/unit_tests/models/test_simplemnl.py b/tests/unit_tests/models/test_simplemnl.py index 50932482..9cb7c1d1 100644 --- a/tests/unit_tests/models/test_simplemnl.py +++ b/tests/unit_tests/models/test_simplemnl.py @@ -1,6 +1,9 @@ """Tests for the SimpleMNL model.""" +import shutil + import numpy as np +import tensorflow as tf from choice_learn.data import ChoiceDataset from choice_learn.models import SimpleMNL @@ -31,7 +34,14 @@ def test_simplemnl_instantiation(): def test_fit_lbfgs(): """Tests instantiation with item-full and fit with lbfgs.""" - model = SimpleMNL(intercept="item-full", optimizer="lbfgs", epochs=20) + tf.config.run_functions_eagerly(True) + model = SimpleMNL( + intercept="item-full", + optimizer="lbfgs", + epochs=20, + regularization="l2", + regularization_strength=0.01, + ) model.instantiate(n_items=3, n_items_features=2, n_shared_features=3) nll_b = model.evaluate(test_dataset) model.fit(test_dataset, get_report=True) @@ -55,6 +65,7 @@ def test_fit_lbfgs(): def test_fit_adam(): """Tests instantiation with item and fit with Adam.""" + tf.config.run_functions_eagerly(True) model = SimpleMNL(intercept="item", optimizer="Adam", epochs=100, lr=0.1) model.instantiate(n_items=3, n_items_features=2, n_shared_features=3) nll_b = model.evaluate(test_dataset) @@ -63,3 +74,43 @@ def test_fit_adam(): assert nll_a < nll_b assert model.report.to_numpy().shape == (7, 5) + + +def test_fit_adam_weights(): + """Tests instantiation with item and fit with Adam.""" + tf.config.run_functions_eagerly(True) + model = SimpleMNL( + intercept="item", + optimizer="Adam", + epochs=100, + lr=0.1, + regularization="l1", + regularization_strength=0.01, + ) + model.instantiate(n_items=3, n_items_features=2, n_shared_features=3) + nll_b = model.evaluate(test_dataset) + model.fit( + test_dataset, + sample_weight=np.array([0.2, 0.4, 0.8, 1.0]), + get_report=True, + val_dataset=test_dataset, + ) + nll_a = model.evaluate(test_dataset, batch_size=-1) + nll_c = model.evaluate(test_dataset, batch_size=3) + assert nll_a < nll_b + assert nll_c == nll_a + + assert model.report.to_numpy().shape == (7, 5) + + +def test_save_load(): + """Tests instantiation with item and fit with Adam.""" + model = SimpleMNL(intercept="item", optimizer="Adam", epochs=100, lr=0.1) + model.instantiate(n_items=3, n_items_features=2, n_shared_features=3) + nll_b = model.evaluate(test_dataset) + model.save_model("test_save") + loaded_model = SimpleMNL.load_model("test_save") + nll_a = loaded_model.evaluate(test_dataset) + + assert nll_a == nll_b + shutil.rmtree("test_save")