Skip to content

Commit

Permalink
fix duplicates in dataset samples, improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SuryaThiru committed Apr 29, 2021
1 parent 0622aa2 commit fcdab7d
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 5 deletions.
2 changes: 1 addition & 1 deletion mlgauge/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.2
0.3.3
8 changes: 5 additions & 3 deletions mlgauge/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,14 +287,16 @@ def _precheck_methods(self, methods):
def _expand_dataset_str(self, dataset_str, n_datasets):
    """Convert the dataset string to a selection of pmlb dataset names.

    Args:
        dataset_str: Which pmlb name list to sample from: "all",
            "classification" or "regression".
        n_datasets: Number of dataset names to draw.

    Returns:
        The value returned by ``self.random_state.choice`` for the selected
        pmlb name list. Names are drawn without replacement, so the same
        dataset is never selected twice.

    Raises:
        ValueError: If ``dataset_str`` is not one of the supported values.
    """
    if dataset_str == "all":
        candidates = pmlb.dataset_names
    elif dataset_str == "classification":
        candidates = pmlb.classification_dataset_names
    elif dataset_str == "regression":
        candidates = pmlb.regression_dataset_names
    else:
        # Fail with a clear message instead of reaching the return with
        # nothing selected.
        raise ValueError(
            f"Unknown dataset string {dataset_str!r}; expected 'all', "
            "'classification' or 'regression'"
        )

    # replace=False samples without replacement, which is what prevents
    # duplicate datasets in the result.
    return self.random_state.choice(candidates, n_datasets, replace=False)

Expand Down
46 changes: 45 additions & 1 deletion tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ def regressor():


# Test all allowed data formats
class MockMethodFormat(Method):
    """Mock method that only verifies the format of the data it is given."""

    def __init__(self):
        super().__init__()

    def train(self, X, y, feature_names, category_indicator=None):
        """Assert that ``X`` and ``y`` are both numpy arrays.

        ``feature_names`` and ``category_indicator`` are accepted but not
        inspected. Nothing is returned.
        """
        # X is checked before y, so a bad X is reported first.
        checks = (
            (X, "Incorrect input type for X"),
            (y, "Incorrect input type for y"),
        )
        for value, message in checks:
            assert isinstance(value, np.ndarray), message


class TestDataFormat:
def test_string(self, regressor, tmp_path):
# should work with "all", "classification", "regression"
Expand Down Expand Up @@ -109,6 +120,37 @@ def test_openml_list(self, regressor, tmp_path):
assert len(an.datasets) == 3
an.run()

def test_array_input(self, regressor, tmp_path):
    """Run the format-checking mock against pmlb and then openml datasets."""

    def build_analysis(**source_options):
        # Build fresh method and metric lists for every Analysis instance so
        # the two runs share no mutable arguments.
        return Analysis(
            methods=[("dummy", MockMethodFormat())],
            metric_names=["r2", "max_error"],
            random_state=SEED,
            output_dir=tmp_path,
            local_cache_dir=PMLB_CACHE,
            use_test_set=False,
            disable_progress=True,
            **source_options,
        )

    pmlb_analysis = build_analysis(
        datasets="classification",
        n_datasets=3,
        data_source="pmlb",
    )
    pmlb_analysis.run()

    # sometimes fails due to issues with busy openml servers
    openml_datasets = ["wind", "iris", 1030]
    openml_analysis = build_analysis(
        datasets=openml_datasets,
        data_source="openml",
    )
    openml_analysis.run()

def test_tuple(self, regressor, tmp_path):
# should work with a list of (X, y) tuples
datasets = [
Expand Down Expand Up @@ -439,7 +481,9 @@ def test_result_cv(tmp_path):
tree.set_test_set(False)
# check if the results match
for data in an.datasets:
X, y = pmlb.fetch_data(data, return_X_y=True)
X, y = pmlb.fetch_data(
data, return_X_y=True, local_cache_dir=PMLB_CACHE, dropna=False
)

linear_r2, linear_max = linear.train(X, y)
tree_r2, tree_max = tree.train(X, y)
Expand Down

0 comments on commit fcdab7d

Please sign in to comment.