Skip to content

Commit

Permalink
ADD: Data tests (#207)
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentAuriau authored Dec 28, 2024
1 parent b2bcc5e commit 9928d46
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 3 deletions.
7 changes: 5 additions & 2 deletions choice_learn/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def load_modecanada(
if row.choice == 1:
named_choice[n_row - 1] = row.alt

canada_df["choice"] = named_choice
canada_df["named_choice"] = named_choice

if as_frame:
if split_features:
Expand All @@ -667,6 +667,9 @@ def load_modecanada(
items_features_by_choice,
choices,
)
if choice_format == "items_id":
canada_df["choice"] = canada_df["named_choice"]
canada_df = canada_df.drop("named_choice", axis=1)
return canada_df

if split_features:
Expand All @@ -687,7 +690,7 @@ def load_modecanada(
cf.append(context_df.loc[context_df.alt == item][items_features].to_numpy()[0])
cav.append(1)
else:
cf.append([0.0, 0.0, 0.0, 0.0])
cf.append([0.0 for _ in range(len(items_features))])
cav.append(0)
cif.append(cf)
ci_av.append(cav)
Expand Down
21 changes: 21 additions & 0 deletions tests/data/test_data.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
,srch_id,prop_id,site_id
0,1,893,12
1,1,10404,12
2,1,21315,12
3,1,27348,12
4,1,29604,12
5,1,30184,12
6,1,44147,12
7,1,50984,12
8,1,53341,12
9,1,56880,12
10,1,59267,12
11,1,59526,12
12,1,68914,12
13,1,74474,12
14,1,81437,12
15,1,85728,12
16,1,88096,12
17,1,88127,12
18,1,88218,12
19,1,89073,12
11 changes: 11 additions & 0 deletions tests/unit_tests/datasets/test_expedia.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Unit testing for Expedia loader."""

import pytest

from choice_learn.datasets import load_expedia


def test_raise_filenotfound():
"""Test that error raised if no file exist."""
with pytest.raises(FileNotFoundError):
load_expedia()
81 changes: 80 additions & 1 deletion tests/unit_tests/test_os_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
load_tafeng,
load_train,
)
from choice_learn.datasets.base import load_csv, load_gzip, slice_from_names


def test_swissmetro_loader():
Expand All @@ -29,15 +30,73 @@ def test_swissmetro_loader():
assert isinstance(swissmetro, ChoiceDataset)


def test_swissmetro_long_format():
"""Test loading the Swissmetro dataset in long format."""
swissmetro = load_swissmetro(as_frame=True, preprocessing="long_format")
assert isinstance(swissmetro, pd.DataFrame)
assert swissmetro.shape == (30474, 7)


def test_swissmetro_tastenet():
"""Test TasteNet preprocessing of dataset."""
_ = load_swissmetro(preprocessing="tastenet")


def test_swissmetro_tutorial():
"""Test tutorial preprocessing of dataset."""
_ = load_swissmetro(preprocessing="tutorial")


def test_biogeme_nested_tutorial():
"""Test biogeme_nested preprocessing of dataset."""
_ = load_swissmetro(preprocessing="biogeme_nested")


def test_rumnet_tutorial():
"""Test rumnet preprocessing of dataset."""
_ = load_swissmetro(preprocessing="rumnet")


def test_modecanada_loader():
"""Test loading the Canada dataset."""
canada = load_modecanada(as_frame=True)
canada = load_modecanada(as_frame=True, choice_format="items_id")
assert isinstance(canada, pd.DataFrame)
assert canada.shape == (15520, 11)

canada = load_modecanada()
assert isinstance(canada, ChoiceDataset)

ca, na, da = load_modecanada(
as_frame=True,
add_items_one_hot=True,
add_is_public=True,
choice_format="items_id",
split_features=True,
)
assert ca.shape == (4324, 4)
assert na.shape == (15520, 11)
assert da.shape == (4324, 2)


def test_modecanada_features_split():
"""Test that features are split well."""
(
o,
ca,
na,
da,
) = load_modecanada(add_items_one_hot=True, add_is_public=True, split_features=True)
assert o.shape == (4324, 3)
assert ca.shape == (4324, 4, 9)
assert na.shape == (4324, 4)
assert da.shape == (4324,)


def test_modecanada_loader_2():
"""Test loading the Canada dataset w/ preprocessing."""
canada = load_modecanada(preprocessing="tutorial", add_items_one_hot=True)
assert isinstance(canada, ChoiceDataset)


def test_electricity_loader():
"""Test loading the Electricity dataset."""
Expand Down Expand Up @@ -324,3 +383,23 @@ def test_londonpassenger_loader():
"distance",
]
assert londonpassenger.shared_features_by_choice_names[0] == expected_shared_features_names


def test_description():
"""Test getting description."""
_ = load_swissmetro(return_desc=True)
_ = load_modecanada(return_desc=True)
_ = load_heating(return_desc=True)
_ = load_electricity(return_desc=True)
_ = load_train(return_desc=True)
_ = load_car_preferences(return_desc=True)
_ = load_hc(return_desc=True)
_ = load_londonpassenger(return_desc=True)
_ = load_tafeng(return_desc=True)


def test_load_csv():
"""Test csv file loader."""
_ = load_csv(data_file_name="test_data.csv", data_module="tests/data")
names, data = load_gzip("swissmetro.csv.gz", data_module="choice_learn/datasets/data")
_ = slice_from_names(data, names[:4], names)

0 comments on commit 9928d46

Please sign in to comment.