Skip to content

Commit

Permalink
Test eager mode for some model tests (#161)
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentAuriau authored Sep 25, 2024
1 parent ae059f8 commit 2ec81cb
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,4 @@ secrets/*
choice_learn/datasets/data/expedia.csv
choice_learn/datasets/cache/*
!choice_learn/datasets/cache/.gitkeep
choice_learn/datasets/data/lmpc.dat
choice_learn/datasets/data/lpmc.dat
24 changes: 14 additions & 10 deletions choice_learn/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,15 +519,18 @@ def save_model(self, path):
path : str
path to the folder where to save the model
"""
if not os.exists(path):
if not os.path.exists(path):
Path(path).mkdir(parents=True)

for i, weight in enumerate(self.trainable_weights):
tf.keras.savedmodel.save(Path(path) / f"weight_{i}")
np.save(Path(path) / f"weight_{i}.npy", weight.numpy())

# To improve for non-string attributes
params = self.__dict__
json.dump(Path(path) / "params.json", params)
params = {}
for k, v in self.__dict__.items():
if isinstance(v, (int, float, str, dict)):
params[k] = v
json.dump(params, open(os.path.join(path, "params.json"), "w"))

# Save optimizer state

Expand All @@ -546,21 +549,22 @@ def load_model(cls, path):
Loaded ChoiceModel
"""
obj = cls()
obj.trainable_weights = []
obj._trainable_weights = []

i = 0
weight_path = f"weight_{i}"
weight_path = f"weight_{i}.npy"
while weight_path in os.listdir(path):
obj.trainable_weights.append(tf.keras.load_model.load(Path(path) / weight_path))
obj._trainable_weights.append(tf.Variable(np.load(Path(path) / weight_path)))
i += 1
weight_path = f"weight_{i}"
weight_path = f"weight_{i}.npy"

# To improve for non string attributes
params = json.load(Path(path) / "params.json")
params = json.load(open(Path(path) / "params.json", "r"))
for k, v in params.items():
setattr(obj, k, v)

# Load optimizer step
return cls
return obj

def predict_probas(self, choice_dataset, batch_size=-1):
"""Predicts the choice probabilities for each choice and each product of a ChoiceDataset.
Expand Down
2 changes: 1 addition & 1 deletion choice_learn/models/simple_mnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,5 +351,5 @@ def clone(self):
if hasattr(self, "_items_features_names"):
clone._items_features_names = self._items_features_names
if hasattr(self, "_shared_features_names"):
clone._contexts_features_names = self._contexts_features_names
clone._shared_features_names = self._shared_features_names
return clone
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ select = [
"PTH",
"PD",
] # See: https://beta.ruff.rs/docs/rules/
ignore = ["D203", "D213", "ANN101", "ANN102", "ANN204", "ANN001", "ANN002", "ANN202", "ANN201", "ANN206", "ANN003", "PTH100", "PTH118", "PTH123","PTH113", "PTH104"]
ignore = ["D203", "D213", "ANN101", "ANN102", "ANN204", "ANN001", "ANN002", "ANN202", "ANN201", "ANN206", "ANN003", "PTH100", "PTH110", "PTH118", "PTH123","PTH113", "PTH104"]
line-length = 100
exclude = [
".bzr",
Expand Down
53 changes: 52 additions & 1 deletion tests/unit_tests/models/test_simplemnl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Tests for the SimpleMNL model."""

import shutil

import numpy as np
import tensorflow as tf

from choice_learn.data import ChoiceDataset
from choice_learn.models import SimpleMNL
Expand Down Expand Up @@ -31,7 +34,14 @@ def test_simplemnl_instantiation():

def test_fit_lbfgs():
"""Tests instantiation with item-full and fit with lbfgs."""
model = SimpleMNL(intercept="item-full", optimizer="lbfgs", epochs=20)
tf.config.run_functions_eagerly(True)
model = SimpleMNL(
intercept="item-full",
optimizer="lbfgs",
epochs=20,
regularization="l2",
regularization_strength=0.01,
)
model.instantiate(n_items=3, n_items_features=2, n_shared_features=3)
nll_b = model.evaluate(test_dataset)
model.fit(test_dataset, get_report=True)
Expand All @@ -55,6 +65,7 @@ def test_fit_lbfgs():

def test_fit_adam():
"""Tests instantiation with item and fit with Adam."""
tf.config.run_functions_eagerly(True)
model = SimpleMNL(intercept="item", optimizer="Adam", epochs=100, lr=0.1)
model.instantiate(n_items=3, n_items_features=2, n_shared_features=3)
nll_b = model.evaluate(test_dataset)
Expand All @@ -63,3 +74,43 @@ def test_fit_adam():
assert nll_a < nll_b

assert model.report.to_numpy().shape == (7, 5)


def test_fit_adam_weights():
"""Tests instantiation with item and fit with Adam."""
tf.config.run_functions_eagerly(True)
model = SimpleMNL(
intercept="item",
optimizer="Adam",
epochs=100,
lr=0.1,
regularization="l1",
regularization_strength=0.01,
)
model.instantiate(n_items=3, n_items_features=2, n_shared_features=3)
nll_b = model.evaluate(test_dataset)
model.fit(
test_dataset,
sample_weight=np.array([0.2, 0.4, 0.8, 1.0]),
get_report=True,
val_dataset=test_dataset,
)
nll_a = model.evaluate(test_dataset, batch_size=-1)
nll_c = model.evaluate(test_dataset, batch_size=3)
assert nll_a < nll_b
assert nll_c == nll_a

assert model.report.to_numpy().shape == (7, 5)


def test_save_load():
"""Tests instantiation with item and fit with Adam."""
model = SimpleMNL(intercept="item", optimizer="Adam", epochs=100, lr=0.1)
model.instantiate(n_items=3, n_items_features=2, n_shared_features=3)
nll_b = model.evaluate(test_dataset)
model.save_model("test_save")
loaded_model = SimpleMNL.load_model("test_save")
nll_a = loaded_model.evaluate(test_dataset)

assert nll_a == nll_b
shutil.rmtree("test_save")

0 comments on commit 2ec81cb

Please sign in to comment.