From f8c829739e643c46923e530e883d1082f13ad792 Mon Sep 17 00:00:00 2001 From: VincentAURIAU Date: Fri, 27 Dec 2024 19:58:27 +0100 Subject: [PATCH] ADD: EM tests & more --- .../models/test_latent_class.py | 21 +++++++++++++++++++ tests/unit_tests/models/test_rumnet_unit.py | 4 ++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/tests/integration_tests/models/test_latent_class.py b/tests/integration_tests/models/test_latent_class.py index 1c49cf63..675ad444 100644 --- a/tests/integration_tests/models/test_latent_class.py +++ b/tests/integration_tests/models/test_latent_class.py @@ -23,6 +23,11 @@ def test_latent_simple_mnl(): _, _ = lc_model.fit(elec_dataset) lc_model.compute_report(elec_dataset) + probas = lc_model.predict_modelwise_probas(elec_dataset) + assert probas.shape == (2, len(elec_dataset), 4) + probas = lc_model.predict_probas(elec_dataset) + assert probas.shape == (len(elec_dataset), 4) + assert lc_model.evaluate(elec_dataset).numpy() < 1.15 @@ -85,3 +90,19 @@ def test_manual_lc_gd(): nll_before = manual_lc.evaluate(elec_dataset) _ = manual_lc.fit(elec_dataset) assert manual_lc.evaluate(elec_dataset) < nll_before + + +def test_em_fit(): + """Test EM algorithm to estimate Latent Class Model.""" + lc_model_em = LatentClassSimpleMNL( + n_latent_classes=3, fit_method="EM", optimizer="lbfgs", epochs=15, lbfgs_tolerance=1e-6 + ) + lc_model_em.instantiate( + n_items=elec_dataset.get_n_items(), + n_shared_features=elec_dataset.get_n_shared_features(), + n_items_features=elec_dataset.get_n_items_features(), + ) + nll_b = lc_model_em.evaluate(elec_dataset) + _, _ = lc_model_em.fit(elec_dataset, verbose=0) + nll_a = lc_model_em.evaluate(elec_dataset) + assert nll_a < nll_b diff --git a/tests/unit_tests/models/test_rumnet_unit.py b/tests/unit_tests/models/test_rumnet_unit.py index 94578b62..61993fd8 100644 --- a/tests/unit_tests/models/test_rumnet_unit.py +++ b/tests/unit_tests/models/test_rumnet_unit.py @@ -254,8 +254,8 @@ def test_gpu_rumnet(): depth_u=3, tol=1e-5, optimizer="adam", - lr=0.01, - epochs=5, + lr=0.001, + epochs=10, ) model.instantiate() assert model.batch_predict(