Skip to content

Commit

Permalink
ADD: Latent Class Report (#201)
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentAuriau authored Dec 23, 2024
1 parent c1c59c8 commit d6ebd96
Show file tree
Hide file tree
Showing 3 changed files with 295 additions and 12 deletions.
35 changes: 35 additions & 0 deletions choice_learn/models/latent_class_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time

import numpy as np
import pandas as pd
import tensorflow as tf
import tqdm

Expand Down Expand Up @@ -930,3 +931,37 @@ def get_latent_classes_weights(self):
Latent classes weights/probabilities
"""
return tf.nn.softmax(tf.concat([[tf.constant(0.0)], self.latent_logits], axis=0))

def compute_report(self, choice_dataset):
"""Compute a report of the estimated weights.
Parameters
----------
choice_dataset : ChoiceDataset
ChoiceDataset used for the estimation of the weights that will be
used to compute the Std Err of this estimation.
Returns
-------
pandas.DataFrame
A DF with estimation, Std Err, z_value and p_value for each coefficient.
"""
reports = []
for i, model in enumerate(self.models):
compute = getattr(model, "compute_report", None)
if callable(compute):
report = model.compute_report(choice_dataset)
report["Latent Class"] = i
reports.append(report)
else:
raise ValueError(f"{i}-th model {model} does not have a compute_report method.")
return pd.concat(reports, axis=0, ignore_index=True)[
[
"Latent Class",
"Coefficient Name",
"Coefficient Estimation",
"Std. Err",
"z_value",
"P(.>z)",
]
]
252 changes: 249 additions & 3 deletions notebooks/models/latent_class_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
"\n",
"import matplotlib as mpl\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
Expand Down Expand Up @@ -106,9 +107,254 @@
"print(f\"Negative Log-Likelihood: {nll}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"keep_output": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using L-BFGS optimizer, setting up .fit() function\n",
"Using L-BFGS optimizer, setting up .fit() function\n",
"Using L-BFGS optimizer, setting up .fit() function\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/zz/r1py7zhj35q75v09h8_42nzh0000gp/T/ipykernel_67121/1263996749.py:4: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.\n",
" cmap = mpl.cm.get_cmap(\"Set1\")\n"
]
},
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_279d5_row0_col0, #T_279d5_row0_col1, #T_279d5_row0_col2, #T_279d5_row0_col3, #T_279d5_row0_col4, #T_279d5_row0_col5, #T_279d5_row1_col0, #T_279d5_row1_col1, #T_279d5_row1_col2, #T_279d5_row1_col3, #T_279d5_row1_col4, #T_279d5_row1_col5, #T_279d5_row2_col0, #T_279d5_row2_col1, #T_279d5_row2_col2, #T_279d5_row2_col3, #T_279d5_row2_col4, #T_279d5_row2_col5, #T_279d5_row3_col0, #T_279d5_row3_col1, #T_279d5_row3_col2, #T_279d5_row3_col3, #T_279d5_row3_col4, #T_279d5_row3_col5, #T_279d5_row4_col0, #T_279d5_row4_col1, #T_279d5_row4_col2, #T_279d5_row4_col3, #T_279d5_row4_col4, #T_279d5_row4_col5, #T_279d5_row5_col0, #T_279d5_row5_col1, #T_279d5_row5_col2, #T_279d5_row5_col3, #T_279d5_row5_col4, #T_279d5_row5_col5 {\n",
" background-color: #e41a1c;\n",
"}\n",
"#T_279d5_row6_col0, #T_279d5_row6_col1, #T_279d5_row6_col2, #T_279d5_row6_col3, #T_279d5_row6_col4, #T_279d5_row6_col5, #T_279d5_row7_col0, #T_279d5_row7_col1, #T_279d5_row7_col2, #T_279d5_row7_col3, #T_279d5_row7_col4, #T_279d5_row7_col5, #T_279d5_row8_col0, #T_279d5_row8_col1, #T_279d5_row8_col2, #T_279d5_row8_col3, #T_279d5_row8_col4, #T_279d5_row8_col5, #T_279d5_row9_col0, #T_279d5_row9_col1, #T_279d5_row9_col2, #T_279d5_row9_col3, #T_279d5_row9_col4, #T_279d5_row9_col5, #T_279d5_row10_col0, #T_279d5_row10_col1, #T_279d5_row10_col2, #T_279d5_row10_col3, #T_279d5_row10_col4, #T_279d5_row10_col5, #T_279d5_row11_col0, #T_279d5_row11_col1, #T_279d5_row11_col2, #T_279d5_row11_col3, #T_279d5_row11_col4, #T_279d5_row11_col5 {\n",
" background-color: #377eb8;\n",
"}\n",
"#T_279d5_row12_col0, #T_279d5_row12_col1, #T_279d5_row12_col2, #T_279d5_row12_col3, #T_279d5_row12_col4, #T_279d5_row12_col5, #T_279d5_row13_col0, #T_279d5_row13_col1, #T_279d5_row13_col2, #T_279d5_row13_col3, #T_279d5_row13_col4, #T_279d5_row13_col5, #T_279d5_row14_col0, #T_279d5_row14_col1, #T_279d5_row14_col2, #T_279d5_row14_col3, #T_279d5_row14_col4, #T_279d5_row14_col5, #T_279d5_row15_col0, #T_279d5_row15_col1, #T_279d5_row15_col2, #T_279d5_row15_col3, #T_279d5_row15_col4, #T_279d5_row15_col5, #T_279d5_row16_col0, #T_279d5_row16_col1, #T_279d5_row16_col2, #T_279d5_row16_col3, #T_279d5_row16_col4, #T_279d5_row16_col5, #T_279d5_row17_col0, #T_279d5_row17_col1, #T_279d5_row17_col2, #T_279d5_row17_col3, #T_279d5_row17_col4, #T_279d5_row17_col5 {\n",
" background-color: #4daf4a;\n",
"}\n",
"</style>\n",
"<table id=\"T_279d5\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_279d5_level0_col0\" class=\"col_heading level0 col0\" >Latent Class</th>\n",
" <th id=\"T_279d5_level0_col1\" class=\"col_heading level0 col1\" >Coefficient Name</th>\n",
" <th id=\"T_279d5_level0_col2\" class=\"col_heading level0 col2\" >Coefficient Estimation</th>\n",
" <th id=\"T_279d5_level0_col3\" class=\"col_heading level0 col3\" >Std. Err</th>\n",
" <th id=\"T_279d5_level0_col4\" class=\"col_heading level0 col4\" >z_value</th>\n",
" <th id=\"T_279d5_level0_col5\" class=\"col_heading level0 col5\" >P(.>z)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
" <td id=\"T_279d5_row0_col0\" class=\"data row0 col0\" >0</td>\n",
" <td id=\"T_279d5_row0_col1\" class=\"data row0 col1\" >Weights_items_features_0</td>\n",
" <td id=\"T_279d5_row0_col2\" class=\"data row0 col2\" >-0.675645</td>\n",
" <td id=\"T_279d5_row0_col3\" class=\"data row0 col3\" >0.023987</td>\n",
" <td id=\"T_279d5_row0_col4\" class=\"data row0 col4\" >-28.167109</td>\n",
" <td id=\"T_279d5_row0_col5\" class=\"data row0 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
" <td id=\"T_279d5_row1_col0\" class=\"data row1 col0\" >0</td>\n",
" <td id=\"T_279d5_row1_col1\" class=\"data row1 col1\" >Weights_items_features_1</td>\n",
" <td id=\"T_279d5_row1_col2\" class=\"data row1 col2\" >-0.060604</td>\n",
" <td id=\"T_279d5_row1_col3\" class=\"data row1 col3\" >0.008162</td>\n",
" <td id=\"T_279d5_row1_col4\" class=\"data row1 col4\" >-7.424849</td>\n",
" <td id=\"T_279d5_row1_col5\" class=\"data row1 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n",
" <td id=\"T_279d5_row2_col0\" class=\"data row2 col0\" >0</td>\n",
" <td id=\"T_279d5_row2_col1\" class=\"data row2 col1\" >Weights_items_features_2</td>\n",
" <td id=\"T_279d5_row2_col2\" class=\"data row2 col2\" >1.851951</td>\n",
" <td id=\"T_279d5_row2_col3\" class=\"data row2 col3\" >0.054914</td>\n",
" <td id=\"T_279d5_row2_col4\" class=\"data row2 col4\" >33.724579</td>\n",
" <td id=\"T_279d5_row2_col5\" class=\"data row2 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n",
" <td id=\"T_279d5_row3_col0\" class=\"data row3 col0\" >0</td>\n",
" <td id=\"T_279d5_row3_col1\" class=\"data row3 col1\" >Weights_items_features_3</td>\n",
" <td id=\"T_279d5_row3_col2\" class=\"data row3 col2\" >1.322549</td>\n",
" <td id=\"T_279d5_row3_col3\" class=\"data row3 col3\" >0.048159</td>\n",
" <td id=\"T_279d5_row3_col4\" class=\"data row3 col4\" >27.462420</td>\n",
" <td id=\"T_279d5_row3_col5\" class=\"data row3 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n",
" <td id=\"T_279d5_row4_col0\" class=\"data row4 col0\" >0</td>\n",
" <td id=\"T_279d5_row4_col1\" class=\"data row4 col1\" >Weights_items_features_4</td>\n",
" <td id=\"T_279d5_row4_col2\" class=\"data row4 col2\" >-5.857089</td>\n",
" <td id=\"T_279d5_row4_col3\" class=\"data row4 col3\" >0.191162</td>\n",
" <td id=\"T_279d5_row4_col4\" class=\"data row4 col4\" >-30.639460</td>\n",
" <td id=\"T_279d5_row4_col5\" class=\"data row4 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row5\" class=\"row_heading level0 row5\" >5</th>\n",
" <td id=\"T_279d5_row5_col0\" class=\"data row5 col0\" >0</td>\n",
" <td id=\"T_279d5_row5_col1\" class=\"data row5 col1\" >Weights_items_features_5</td>\n",
" <td id=\"T_279d5_row5_col2\" class=\"data row5 col2\" >-6.513206</td>\n",
" <td id=\"T_279d5_row5_col3\" class=\"data row5 col3\" >0.195680</td>\n",
" <td id=\"T_279d5_row5_col4\" class=\"data row5 col4\" >-33.285046</td>\n",
" <td id=\"T_279d5_row5_col5\" class=\"data row5 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row6\" class=\"row_heading level0 row6\" >6</th>\n",
" <td id=\"T_279d5_row6_col0\" class=\"data row6 col0\" >1</td>\n",
" <td id=\"T_279d5_row6_col1\" class=\"data row6 col1\" >Weights_items_features_0</td>\n",
" <td id=\"T_279d5_row6_col2\" class=\"data row6 col2\" >-1.817566</td>\n",
" <td id=\"T_279d5_row6_col3\" class=\"data row6 col3\" >0.077771</td>\n",
" <td id=\"T_279d5_row6_col4\" class=\"data row6 col4\" >-23.370796</td>\n",
" <td id=\"T_279d5_row6_col5\" class=\"data row6 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row7\" class=\"row_heading level0 row7\" >7</th>\n",
" <td id=\"T_279d5_row7_col0\" class=\"data row7 col0\" >1</td>\n",
" <td id=\"T_279d5_row7_col1\" class=\"data row7 col1\" >Weights_items_features_1</td>\n",
" <td id=\"T_279d5_row7_col2\" class=\"data row7 col2\" >-1.726365</td>\n",
" <td id=\"T_279d5_row7_col3\" class=\"data row7 col3\" >0.058838</td>\n",
" <td id=\"T_279d5_row7_col4\" class=\"data row7 col4\" >-29.340986</td>\n",
" <td id=\"T_279d5_row7_col5\" class=\"data row7 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row8\" class=\"row_heading level0 row8\" >8</th>\n",
" <td id=\"T_279d5_row8_col0\" class=\"data row8 col0\" >1</td>\n",
" <td id=\"T_279d5_row8_col1\" class=\"data row8 col1\" >Weights_items_features_2</td>\n",
" <td id=\"T_279d5_row8_col2\" class=\"data row8 col2\" >3.696567</td>\n",
" <td id=\"T_279d5_row8_col3\" class=\"data row8 col3\" >0.160258</td>\n",
" <td id=\"T_279d5_row8_col4\" class=\"data row8 col4\" >23.066404</td>\n",
" <td id=\"T_279d5_row8_col5\" class=\"data row8 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row9\" class=\"row_heading level0 row9\" >9</th>\n",
" <td id=\"T_279d5_row9_col0\" class=\"data row9 col0\" >1</td>\n",
" <td id=\"T_279d5_row9_col1\" class=\"data row9 col1\" >Weights_items_features_3</td>\n",
" <td id=\"T_279d5_row9_col2\" class=\"data row9 col2\" >4.111840</td>\n",
" <td id=\"T_279d5_row9_col3\" class=\"data row9 col3\" >0.157179</td>\n",
" <td id=\"T_279d5_row9_col4\" class=\"data row9 col4\" >26.160225</td>\n",
" <td id=\"T_279d5_row9_col5\" class=\"data row9 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row10\" class=\"row_heading level0 row10\" >10</th>\n",
" <td id=\"T_279d5_row10_col0\" class=\"data row10 col0\" >1</td>\n",
" <td id=\"T_279d5_row10_col1\" class=\"data row10 col1\" >Weights_items_features_4</td>\n",
" <td id=\"T_279d5_row10_col2\" class=\"data row10 col2\" >-26.693516</td>\n",
" <td id=\"T_279d5_row10_col3\" class=\"data row10 col3\" >3.274723</td>\n",
" <td id=\"T_279d5_row10_col4\" class=\"data row10 col4\" >-8.151381</td>\n",
" <td id=\"T_279d5_row10_col5\" class=\"data row10 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row11\" class=\"row_heading level0 row11\" >11</th>\n",
" <td id=\"T_279d5_row11_col0\" class=\"data row11 col0\" >1</td>\n",
" <td id=\"T_279d5_row11_col1\" class=\"data row11 col1\" >Weights_items_features_5</td>\n",
" <td id=\"T_279d5_row11_col2\" class=\"data row11 col2\" >-14.925840</td>\n",
" <td id=\"T_279d5_row11_col3\" class=\"data row11 col3\" >0.634699</td>\n",
" <td id=\"T_279d5_row11_col4\" class=\"data row11 col4\" >-23.516403</td>\n",
" <td id=\"T_279d5_row11_col5\" class=\"data row11 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row12\" class=\"row_heading level0 row12\" >12</th>\n",
" <td id=\"T_279d5_row12_col0\" class=\"data row12 col0\" >2</td>\n",
" <td id=\"T_279d5_row12_col1\" class=\"data row12 col1\" >Weights_items_features_0</td>\n",
" <td id=\"T_279d5_row12_col2\" class=\"data row12 col2\" >-2.104791</td>\n",
" <td id=\"T_279d5_row12_col3\" class=\"data row12 col3\" >0.104296</td>\n",
" <td id=\"T_279d5_row12_col4\" class=\"data row12 col4\" >-20.181009</td>\n",
" <td id=\"T_279d5_row12_col5\" class=\"data row12 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row13\" class=\"row_heading level0 row13\" >13</th>\n",
" <td id=\"T_279d5_row13_col0\" class=\"data row13 col0\" >2</td>\n",
" <td id=\"T_279d5_row13_col1\" class=\"data row13 col1\" >Weights_items_features_1</td>\n",
" <td id=\"T_279d5_row13_col2\" class=\"data row13 col2\" >-1.652622</td>\n",
" <td id=\"T_279d5_row13_col3\" class=\"data row13 col3\" >0.073820</td>\n",
" <td id=\"T_279d5_row13_col4\" class=\"data row13 col4\" >-22.387188</td>\n",
" <td id=\"T_279d5_row13_col5\" class=\"data row13 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row14\" class=\"row_heading level0 row14\" >14</th>\n",
" <td id=\"T_279d5_row14_col0\" class=\"data row14 col0\" >2</td>\n",
" <td id=\"T_279d5_row14_col1\" class=\"data row14 col1\" >Weights_items_features_2</td>\n",
" <td id=\"T_279d5_row14_col2\" class=\"data row14 col2\" >-5.554287</td>\n",
" <td id=\"T_279d5_row14_col3\" class=\"data row14 col3\" >0.245318</td>\n",
" <td id=\"T_279d5_row14_col4\" class=\"data row14 col4\" >-22.641151</td>\n",
" <td id=\"T_279d5_row14_col5\" class=\"data row14 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row15\" class=\"row_heading level0 row15\" >15</th>\n",
" <td id=\"T_279d5_row15_col0\" class=\"data row15 col0\" >2</td>\n",
" <td id=\"T_279d5_row15_col1\" class=\"data row15 col1\" >Weights_items_features_3</td>\n",
" <td id=\"T_279d5_row15_col2\" class=\"data row15 col2\" >-13.565555</td>\n",
" <td id=\"T_279d5_row15_col3\" class=\"data row15 col3\" >0.544168</td>\n",
" <td id=\"T_279d5_row15_col4\" class=\"data row15 col4\" >-24.928965</td>\n",
" <td id=\"T_279d5_row15_col5\" class=\"data row15 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row16\" class=\"row_heading level0 row16\" >16</th>\n",
" <td id=\"T_279d5_row16_col0\" class=\"data row16 col0\" >2</td>\n",
" <td id=\"T_279d5_row16_col1\" class=\"data row16 col1\" >Weights_items_features_4</td>\n",
" <td id=\"T_279d5_row16_col2\" class=\"data row16 col2\" >-9.794930</td>\n",
" <td id=\"T_279d5_row16_col3\" class=\"data row16 col3\" >0.631004</td>\n",
" <td id=\"T_279d5_row16_col4\" class=\"data row16 col4\" >-15.522781</td>\n",
" <td id=\"T_279d5_row16_col5\" class=\"data row16 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row17\" class=\"row_heading level0 row17\" >17</th>\n",
" <td id=\"T_279d5_row17_col0\" class=\"data row17 col0\" >2</td>\n",
" <td id=\"T_279d5_row17_col1\" class=\"data row17 col1\" >Weights_items_features_5</td>\n",
" <td id=\"T_279d5_row17_col2\" class=\"data row17 col2\" >-12.126673</td>\n",
" <td id=\"T_279d5_row17_col3\" class=\"data row17 col3\" >0.681118</td>\n",
" <td id=\"T_279d5_row17_col4\" class=\"data row17 col4\" >-17.804060</td>\n",
" <td id=\"T_279d5_row17_col5\" class=\"data row17 col5\" >0.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f7d56306970>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"report = lc_model.compute_report(elec_dataset)\n",
"\n",
"def format_color_groups(df):\n",
" cmap = mpl.cm.get_cmap(\"Set1\")\n",
" colors = [mpl.colors.rgb2hex(cmap(i)) for i in range(cmap.N)]\n",
" x = df.copy()\n",
" factors = list(x['Latent Class'].unique())\n",
" i = 0\n",
" for factor in factors:\n",
" style = f'background-color: {colors[i]}'\n",
" x.loc[x['Latent Class'] == factor, :] = style\n",
" i += 1\n",
" return x\n",
"\n",
"report.style.apply(format_color_groups, axis=None)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"keep_output": true
},
"source": [
"## Latent Conditional Logit\n",
"We used a very simple MNL. Here we simulate the same MNL, by using the Conditional-Logit formulation.\\\n",
Expand Down Expand Up @@ -281,7 +527,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "tf_env",
"display_name": "choice_learn",
"language": "python",
"name": "python3"
},
Expand All @@ -295,7 +541,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.8.18"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit d6ebd96

Please sign in to comment.