Skip to content

Commit

Permalink
ADD: EM tests & more
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentAuriau committed Dec 27, 2024
1 parent e15d9ef commit f8c8297
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
21 changes: 21 additions & 0 deletions tests/integration_tests/models/test_latent_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def test_latent_simple_mnl():
_, _ = lc_model.fit(elec_dataset)
lc_model.compute_report(elec_dataset)

probas = lc_model.predict_modelwise_probas(elec_dataset)
assert probas.shape == (2, len(elec_dataset), 4)
probas = lc_model.predict_probas(elec_dataset)
assert probas.shape == (len(elec_dataset), 4)

assert lc_model.evaluate(elec_dataset).numpy() < 1.15


Expand Down Expand Up @@ -85,3 +90,19 @@ def test_manual_lc_gd():
nll_before = manual_lc.evaluate(elec_dataset)
_ = manual_lc.fit(elec_dataset)
assert manual_lc.evaluate(elec_dataset) < nll_before


def test_em_fit():
"""Test EM algorithm to estimate Latent Class Model."""
lc_model_em = LatentClassSimpleMNL(
n_latent_classes=3, fit_method="EM", optimizer="lbfgs", epochs=15, lbfgs_tolerance=1e-6
)
lc_model_em.instantiate(
n_items=elec_dataset.get_n_items(),
n_shared_features=elec_dataset.get_n_shared_features(),
n_items_features=elec_dataset.get_n_items_features(),
)
nll_b = lc_model_em.evaluate(elec_dataset)
_, _ = lc_model_em.fit(elec_dataset, verbose=0)
nll_a = lc_model_em.evaluate(elec_dataset)
assert nll_a < nll_b
4 changes: 2 additions & 2 deletions tests/unit_tests/models/test_rumnet_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,8 @@ def test_gpu_rumnet():
depth_u=3,
tol=1e-5,
optimizer="adam",
lr=0.01,
epochs=5,
lr=0.001,
epochs=10,
)
model.instantiate()
assert model.batch_predict(
Expand Down

0 comments on commit f8c8297

Please sign in to comment.