diff --git a/tests/unit_tests/test_os_datasets.py b/tests/unit_tests/test_os_datasets.py index 1e00e28a..e79fd17f 100644 --- a/tests/unit_tests/test_os_datasets.py +++ b/tests/unit_tests/test_os_datasets.py @@ -17,7 +17,7 @@ def test_swissmetro_loader(): """Test loading the Swissmetro dataset.""" swissmetro = load_swissmetro(as_frame=True) assert isinstance(swissmetro, pd.DataFrame) - assert swissmetro.shape == (10728, 28) + assert swissmetro.shape == (10728, 29) swissmetro = load_swissmetro() assert isinstance(swissmetro, ChoiceDataset) @@ -29,6 +29,7 @@ def test_modecanada_loader(): """Test loading the Canada dataset.""" canada = load_modecanada(as_frame=True) assert isinstance(canada, pd.DataFrame) + assert canada.shape == (15520, 12) canada = load_modecanada() assert isinstance(canada, ChoiceDataset) @@ -38,6 +39,7 @@ def test_electricity_loader(): """Test loading the Electricity dataset.""" electricity = load_electricity(as_frame=True) assert isinstance(electricity, pd.DataFrame) + assert electricity.shape == (17232, 10) electricity = load_electricity() assert isinstance(electricity, ChoiceDataset) @@ -47,6 +49,7 @@ def test_train_loader(): """Test loading the Train dataset.""" train = load_train(as_frame=True) assert isinstance(train, pd.DataFrame) + assert train.shape == (2929, 11) train = load_train() assert isinstance(train, ChoiceDataset) @@ -56,6 +59,7 @@ def test_tafeng_loader(): """Test loading the TaFeng dataset.""" tafeng = load_tafeng(as_frame=True) assert isinstance(tafeng, pd.DataFrame) + assert tafeng.shape == (817741, 9) tafeng = load_tafeng() assert isinstance(tafeng, ChoiceDataset) @@ -65,6 +69,7 @@ def test_heating_loader(): """Test loading the heating dataset.""" heating = load_heating(as_frame=True) assert isinstance(heating, pd.DataFrame) + assert heating.shape == (900, 16) heating = load_heating() assert isinstance(heating, ChoiceDataset) @@ -74,3 +79,4 @@ def test_expedia_loader(): """Test loading the Expedia dataset.""" expedia = load_expedia(as_frame=True) assert isinstance(expedia, pd.DataFrame) + assert expedia.shape == (9917530, 54)