From cda484a61fc3133893907ad8339262e5059cb82d Mon Sep 17 00:00:00 2001 From: Vincent Auriau Date: Fri, 27 Dec 2024 10:59:45 -0800 Subject: [PATCH] ENH: Expectation-Maximization Algorithm (#205) --- choice_learn/models/base_model.py | 3 +- .../models/latent_class_base_model.py | 32 +- choice_learn/models/latent_class_mnl.py | 14 +- notebooks/models/latent_class_model.ipynb | 444 ++++++++++++------ 4 files changed, 342 insertions(+), 151 deletions(-) diff --git a/choice_learn/models/base_model.py b/choice_learn/models/base_model.py index ec6ecd44..d95c0bf5 100644 --- a/choice_learn/models/base_model.py +++ b/choice_learn/models/base_model.py @@ -444,7 +444,7 @@ def fit( self.callbacks.on_train_end(logs=temps_logs) return losses_history - @tf.function + @tf.function(reduce_retracing=True) def batch_predict( self, shared_features_by_choice, @@ -731,7 +731,6 @@ def f(params_1d): # calculate gradients and convert to 1D tf.Tensor grads = tape.gradient(loss_value, self.trainable_weights) grads = tf.dynamic_stitch(idx, grads) - # print out iteration & loss f.iter.assign_add(1) # store loss value so we can retrieve later diff --git a/choice_learn/models/latent_class_base_model.py b/choice_learn/models/latent_class_base_model.py index 2f9321af..0c11645d 100644 --- a/choice_learn/models/latent_class_base_model.py +++ b/choice_learn/models/latent_class_base_model.py @@ -104,12 +104,18 @@ def instantiate(self, **kwargs): name="Latent-Logits", ) self.latent_logits = init_logit - self.models = [self.model_class(**mp) for mp in self.model_parameters] - for model in self.models: - model.instantiate(**kwargs) + self.models = self.instantiate_latent_models(**kwargs) self.instantiated = True + def instantiate_latent_models(self, **kwargs): + """Instantiate latent models.""" + models = [self.model_class(**mp) for mp in self.model_parameters] + for model in models: + model.instantiate(**kwargs) + + return models + # @tf.function def batch_predict( self, @@ -249,7 +255,6 @@ def fit(self, choice_dataset, sample_weight=None, verbose=0): """ if self.fit_method.lower() == "em": self.minf = np.log(1e-3) - print("Expectation-Maximization estimation algorithm not well implemented yet.") return self._em_fit( choice_dataset=choice_dataset, sample_weight=sample_weight, verbose=verbose ) @@ -824,7 +829,7 @@ def _expectation(self, choice_dataset): ) return tf.clip_by_value( - predicted_probas / np.sum(predicted_probas, axis=1, keepdims=True), 1e-10, 1 + predicted_probas / np.sum(predicted_probas, axis=1, keepdims=True), 1e-6, 1 ), loss def _maximization(self, choice_dataset, verbose=0): @@ -842,10 +847,17 @@ def _maximization(self, choice_dataset, verbose=0): np.ndarray latent probabilities resulting of maximization step """ - self.models = [self.model_class(**mp) for mp in self.model_parameters] + # models = [self.model_class(**mp) for mp in self.model_parameters] + # for i in range(len(models)): + # for j, var in enumerate(self.models[i].trainable_weights): + # models[i]._trainable_weights[j] = var + # self.instantiate_latent_models(choice_dataset) + # M-step: MNL estimation for q in range(self.n_latent_classes): - self.models[q].fit(choice_dataset, sample_weight=self.weights[:, q], verbose=verbose) + self.models[q].fit( + choice_dataset, sample_weight=self.weights[:, q].numpy(), verbose=verbose + ) # M-step: latent probability estimation latent_probas = np.sum(self.weights, axis=0) @@ -876,7 +888,9 @@ def _em_fit(self, choice_dataset, sample_weight=None, verbose=0): # Initialization init_sample_weight = np.random.rand(self.n_latent_classes, len(choice_dataset)) - init_sample_weight = init_sample_weight / np.sum(init_sample_weight, axis=0, keepdims=True) + init_sample_weight = np.clip( + init_sample_weight / np.sum(init_sample_weight, axis=0, keepdims=True), 1e-6, 1 + ) for i, model in enumerate(self.models): # model.instantiate() model.fit(choice_dataset, sample_weight=init_sample_weight[i], verbose=verbose) @@ -888,7 +902,7 @@ def _em_fit(self, choice_dataset, sample_weight=None, verbose=0): if np.sum(np.isnan(self.latent_logits)) > 0: print("Nan in logits") break - return hist_logits, hist_loss + return hist_loss, hist_logits def predict_probas(self, choice_dataset, batch_size=-1): """Predicts the choice probabilities for each choice and each product of a ChoiceDataset. diff --git a/choice_learn/models/latent_class_mnl.py b/choice_learn/models/latent_class_mnl.py index 1c5b6b83..c53e0478 100644 --- a/choice_learn/models/latent_class_mnl.py +++ b/choice_learn/models/latent_class_mnl.py @@ -4,6 +4,8 @@ import tensorflow as tf +import choice_learn.tf_ops as tf_ops + from .conditional_logit import ConditionalLogit, MNLCoefficients from .latent_class_base_model import BaseLatentClassModel from .simple_mnl import SimpleMNL @@ -23,6 +25,7 @@ def __init__( intercept=None, optimizer="Adam", lr=0.001, + epochs_maximization=1000, **kwargs, ): """Initialize model. @@ -56,7 +59,7 @@ def __init__( "batch_size": batch_size, "lbfgs_tolerance": lbfgs_tolerance, "lr": lr, - "epochs": 1000, + "epochs": epochs_maximization, } super().__init__( @@ -88,6 +91,15 @@ def instantiate_latent_models(self, n_items, n_shared_features, n_items_features model.indexes, model.weights = model.instantiate( n_items, n_shared_features, n_items_features ) + model.exact_nll = tf_ops.CustomCategoricalCrossEntropy( + from_logits=False, + label_smoothing=0.0, + sparse=False, + axis=-1, + epsilon=1e-25, + name="exact_categorical_crossentropy", + reduction="sum_over_batch_size", + ) model.instantiated = True def instantiate(self, n_items, n_shared_features, n_items_features): diff --git a/notebooks/models/latent_class_model.ipynb b/notebooks/models/latent_class_model.ipynb index 82b4ed8f..70fd8c7b 100644 --- a/notebooks/models/latent_class_model.ipynb +++ b/notebooks/models/latent_class_model.ipynb @@ -120,14 +120,30 @@ "text": [ "Using L-BFGS optimizer, setting up .fit() function\n", "Using L-BFGS optimizer, setting up .fit() function\n", - "Using L-BFGS optimizer, setting up .fit() function\n" + "Using L-BFGS optimizer, setting up .fit() function\n", + "WARNING:tensorflow:5 out of the last 5 calls to .f at 0x7fa18fb41af0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/zz/r1py7zhj35q75v09h8_42nzh0000gp/T/ipykernel_67121/1263996749.py:4: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.\n", + "WARNING:tensorflow:5 out of the last 5 calls to .f at 0x7fa18fb41af0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:6 out of the last 6 calls to .f at 0x7fa18fb92280> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:6 out of the last 6 calls to .f at 0x7fa18fb92280> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", + "/var/folders/zz/r1py7zhj35q75v09h8_42nzh0000gp/T/ipykernel_27459/1263996749.py:4: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.\n", " cmap = mpl.cm.get_cmap(\"Set1\")\n" ] }, @@ -135,196 +151,196 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 Latent ClassCoefficient NameCoefficient EstimationStd. Errz_valueP(.>z)Latent ClassCoefficient NameCoefficient EstimationStd. Errz_valueP(.>z)
00Weights_items_features_0-0.6756450.023987-28.1671090.00000000Weights_items_features_0-0.6756450.023987-28.1671090.000000
10Weights_items_features_1-0.0606040.008162-7.4248490.00000010Weights_items_features_1-0.0606040.008162-7.4248490.000000
20Weights_items_features_21.8519510.05491433.7245790.00000020Weights_items_features_21.8519510.05491433.7245790.000000
30Weights_items_features_31.3225490.04815927.4624200.00000030Weights_items_features_31.3225490.04815927.4624200.000000
40Weights_items_features_4-5.8570890.191162-30.6394600.00000040Weights_items_features_4-5.8570890.191162-30.6394600.000000
50Weights_items_features_5-6.5132060.195680-33.2850460.00000050Weights_items_features_5-6.5132060.195680-33.2850460.000000
61Weights_items_features_0-1.8175660.077771-23.3707960.00000061Weights_items_features_0-1.8175660.077771-23.3707960.000000
71Weights_items_features_1-1.7263650.058838-29.3409860.00000071Weights_items_features_1-1.7263650.058838-29.3409860.000000
81Weights_items_features_23.6965670.16025823.0664040.00000081Weights_items_features_23.6965670.16025823.0664040.000000
91Weights_items_features_34.1118400.15717926.1602250.00000091Weights_items_features_34.1118400.15717926.1602250.000000
101Weights_items_features_4-26.6935163.274723-8.1513810.000000101Weights_items_features_4-26.6935163.274723-8.1513810.000000
111Weights_items_features_5-14.9258400.634699-23.5164030.000000111Weights_items_features_5-14.9258400.634699-23.5164030.000000
122Weights_items_features_0-2.1047910.104296-20.1810090.000000122Weights_items_features_0-2.1047910.104296-20.1810090.000000
132Weights_items_features_1-1.6526220.073820-22.3871880.000000132Weights_items_features_1-1.6526220.073820-22.3871880.000000
142Weights_items_features_2-5.5542870.245318-22.6411510.000000142Weights_items_features_2-5.5542870.245318-22.6411510.000000
152Weights_items_features_3-13.5655550.544168-24.9289650.000000152Weights_items_features_3-13.5655550.544168-24.9289650.000000
162Weights_items_features_4-9.7949300.631004-15.5227810.000000162Weights_items_features_4-9.7949300.631004-15.5227810.000000
172Weights_items_features_5-12.1266730.681118-17.8040600.000000172Weights_items_features_5-12.1266730.681118-17.8040600.000000
\n" ], "text/plain": [ - "" + "" ] }, "execution_count": null, @@ -350,6 +366,156 @@ "report.style.apply(format_color_groups, axis=None)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lc_model_2 = LatentClassSimpleMNL(n_latent_classes=3, fit_method=\"EM\", optimizer=\"lbfgs\", epochs=2000, lbfgs_tolerance=1e-6)\n", + "hist, results = lc_model_2.fit(elec_dataset, verbose=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lc_model_2.latent_logits" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.plot(results)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lc_model_2.instantiate(\n", + " n_items=elec_dataset.get_n_items(),\n", + " n_shared_features=elec_dataset.get_n_shared_features(),\n", + " n_items_features=elec_dataset.get_n_items_features(),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hist_logits = []\n", + "hist_loss = []\n", + "\n", + "# Initialization\n", + "init_sample_weight = np.random.rand(3, len(elec_dataset))\n", + "init_sample_weight = init_sample_weight / np.sum(init_sample_weight, axis=0, keepdims=True)\n", + "for i, model in enumerate(lc_model_2.models):\n", + " # model.instantiate()\n", + " model.fit(elec_dataset, sample_weight=np.clip(init_sample_weight[i], 1e-4, 1), verbose=2)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lc_model_2.models[2].exact_nll.epsilon" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "init_sample_weight[2].min()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lc_model_2.models[2].predict_probas(elec_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tf.math.log(1.0 * 1e-40)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lc_model_2.models[2]._trainable_weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "myw, myl, myloss = [], [], []\n", + "\n", + "for i in range(10):\n", + " lc_model_2.weights, loss = lc_model_2._expectation(elec_dataset)\n", + " lc_model_2.latent_logits = lc_model_2._maximization(elec_dataset, verbose=2)\n", + "\n", + " myw.append(lc_model_2.weights)\n", + " myl.append(lc_model_2.latent_logits)\n", + " myloss.append(loss)\n", + " if np.sum(np.isnan(self.latent_logits)) > 0:\n", + " print(\"Nan in logits\")\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for i in tqdm.trange(self.epochs):\n", + " self.weights, loss = self._expectation(choice_dataset)\n", + " self.latent_logits = self._maximization(choice_dataset, verbose=verbose)\n", + " hist_logits.append(self.latent_logits)\n", + " hist_loss.append(loss)\n", + " if np.sum(np.isnan(self.latent_logits)) > 0:\n", + " print(\"Nan in logits\")\n", + " break\n", + "return hist_logits, hist_loss" + ] + }, { "cell_type": "markdown", "metadata": {