diff --git a/mlgauge/VERSION b/mlgauge/VERSION
index d15723f..1c09c74 100644
--- a/mlgauge/VERSION
+++ b/mlgauge/VERSION
@@ -1 +1 @@
-0.3.2
+0.3.3
diff --git a/mlgauge/analysis.py b/mlgauge/analysis.py
index c058ba5..1cacf0a 100644
--- a/mlgauge/analysis.py
+++ b/mlgauge/analysis.py
@@ -287,14 +287,16 @@ def _precheck_methods(self, methods):
     def _expand_dataset_str(self, dataset_str, n_datasets):
         """Convert the dataset string to list of pmlb dataset names"""
         if dataset_str == "all":
-            datasets = self.random_state.choice(pmlb.dataset_names, n_datasets)
+            datasets = self.random_state.choice(
+                pmlb.dataset_names, n_datasets, replace=False
+            )
         elif dataset_str == "classification":
             datasets = self.random_state.choice(
-                pmlb.classification_dataset_names, n_datasets
+                pmlb.classification_dataset_names, n_datasets, replace=False
             )
         elif dataset_str == "regression":
             datasets = self.random_state.choice(
-                pmlb.regression_dataset_names, n_datasets
+                pmlb.regression_dataset_names, n_datasets, replace=False
             )
         return datasets
 
diff --git a/tests/test_analysis.py b/tests/test_analysis.py
index ad992b7..4f26578 100644
--- a/tests/test_analysis.py
+++ b/tests/test_analysis.py
@@ -26,6 +26,17 @@ def regressor():
 
 
 # Test all allowed data formats
+class MockMethodFormat(Method):
+    """Class to check if the input format is correct."""
+
+    def __init__(self):
+        super().__init__()
+
+    def train(self, X, y, feature_names, category_indicator=None):
+        assert isinstance(X, np.ndarray), "Incorrect input type for X"
+        assert isinstance(y, np.ndarray), "Incorrect input type for y"
+
+
 class TestDataFormat:
     def test_string(self, regressor, tmp_path):
         # should work with "all", "classification", "regression"
@@ -109,6 +120,37 @@ def test_openml_list(self, regressor, tmp_path):
         assert len(an.datasets) == 3
         an.run()
 
+    def test_array_input(self, regressor, tmp_path):
+        an = Analysis(
+            methods=[("dummy", MockMethodFormat())],
+            metric_names=["r2", "max_error"],
+            datasets="classification",
+            n_datasets=3,
+            data_source="pmlb",
+            random_state=SEED,
+            output_dir=tmp_path,
+            local_cache_dir=PMLB_CACHE,
+            use_test_set=False,
+            disable_progress=True,
+        )
+        an.run()
+        an = Analysis(
+            methods=[("dummy", MockMethodFormat())],
+            metric_names=["r2", "max_error"],
+            datasets=[
+                "wind",
+                "iris",
+                1030,
+            ],  # sometimes fails due to issues with busy openml servers
+            data_source="openml",
+            random_state=SEED,
+            output_dir=tmp_path,
+            local_cache_dir=PMLB_CACHE,
+            use_test_set=False,
+            disable_progress=True,
+        )
+        an.run()
+
     def test_tuple(self, regressor, tmp_path):
         # should works with a list of (X, y) tuples
         datasets = [
@@ -439,7 +481,9 @@ def test_result_cv(tmp_path):
     tree.set_test_set(False)
     # check if the results match
    for data in an.datasets:
-        X, y = pmlb.fetch_data(data, return_X_y=True)
+        X, y = pmlb.fetch_data(
+            data, return_X_y=True, local_cache_dir=PMLB_CACHE, dropna=False
+        )
         linear_r2, linear_max = linear.train(X, y)
         tree_r2, tree_max = tree.train(X, y)
 
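
Note on the replace=False change in _expand_dataset_str: numpy's RandomState.choice samples with replacement by default, so a single run could draw the same PMLB dataset name more than once and benchmark it twice. Below is a minimal sketch of the difference; the dataset names are made up for illustration and stand in for pmlb.dataset_names.

import numpy as np

# Hypothetical pool of dataset names, standing in for pmlb.dataset_names.
dataset_names = ["iris", "wine", "adult", "titanic", "diabetes"]

rng = np.random.RandomState(42)

# Default behaviour: sampling WITH replacement, so duplicates are possible.
with_replacement = rng.choice(dataset_names, 4)

# With replace=False the draw is without replacement, so the 4 names are
# guaranteed distinct (requires n_datasets <= len(dataset_names)).
without_replacement = rng.choice(dataset_names, 4, replace=False)

print(with_replacement)     # may contain repeated names
print(without_replacement)  # always 4 unique names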