Support test_io_data for MultiDataEval and eval util #84

Merged (3 commits, Feb 29, 2024)
46 changes: 37 additions & 9 deletions src/sensai/evaluation/eval_util.py
@@ -91,12 +91,12 @@ def create_vector_model_cross_validator(data: InputOutputData,

def create_evaluation_util(data: InputOutputData, model: VectorModel = None, is_regression: bool = None,
evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams]] = None,
cross_validator_params: Optional[Dict[str, Any]] = None) \
cross_validator_params: Optional[Dict[str, Any]] = None, test_io_data: Optional[InputOutputData] = None) \
-> Union["ClassificationModelEvaluation", "RegressionModelEvaluation"]:
if _is_regression(model, is_regression):
return RegressionModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params)
return RegressionModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params, test_io_data=test_io_data)
else:
return ClassificationModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params)
return ClassificationModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params, test_io_data=test_io_data)


def eval_model_via_evaluator(model: TModel, io_data: InputOutputData, test_fraction=0.2,
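For orientation, a minimal usage sketch of the extended create_evaluation_util signature; the data-loading helpers below are placeholders and not part of this diff:

# Hedged sketch: pass a fixed test set so the evaluation util does not split the training data itself.
train_io_data = load_train_io_data()   # placeholder -> InputOutputData
test_io_data = load_test_io_data()     # placeholder -> InputOutputData

# With test_io_data given, train_io_data is not split for evaluation;
# metrics are computed on the explicitly provided test set instead.
ev = create_evaluation_util(train_io_data, is_regression=True, test_io_data=test_io_data)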
@@ -576,16 +576,34 @@ class MultiDataModelEvaluation:
def __init__(self, io_data_dict: Dict[str, InputOutputData], key_name: str = "dataset",
meta_data_dict: Optional[Dict[str, Dict[str, Any]]] = None,
evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams, Dict[str, Any]]] = None,
cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None):
cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
test_io_data_dict: Optional[Dict[str, Optional[InputOutputData]]] = None):
"""
:param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models
:param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models.
For evaluation or cross-validation, these datasets will usually be split according to the rules
specified by `evaluator_params` or `cross_validator_params`. An exception is the case where
explicit test data sets are specified by passing `test_io_data_dict`. Then, for these data
sets, the io_data will not be split for evaluation, but the test_io_data will be used instead.
:param key_name: a name for the key value used in inputOutputDataDict, which will be used as a column name in result data frames
:param meta_data_dict: a dictionary which maps from a name (same keys as in inputOutputDataDict) to a dictionary, which maps
from a column name to a value and which is to be used to extend the result data frames containing per-dataset results
:param evaluator_params: parameters to use for the instantiation of evaluators (relevant if useCrossValidation==False)
:param cross_validator_params: parameters to use for the instantiation of cross-validators (relevant if useCrossValidation==True)
:param test_io_data_dict: a dictionary mapping from names to the test data sets to use for evaluation or to None.
Entries with non-None values will be used for evaluation of the models that were trained on the respective entry of io_data_dict.
If passed, the keys need to be a superset of io_data_dict's keys (note that the values may be None, e.g.
if you want to use explicit test data sets for some entries and splitting of the io_data for others).
If not None, cross-validation cannot be used when calling ``compare_models``.
"""
if test_io_data_dict is not None:
missing_keys = set(io_data_dict).difference(test_io_data_dict)
if len(missing_keys) > 0:
raise ValueError(
"If test_io_data_dict is passed, its keys must be a superset of the io_data_dict's keys."
f"However, found missing_keys: {missing_keys}")
self.io_data_dict = io_data_dict
self.test_io_data_dict = test_io_data_dict

self.key_name = key_name
self.evaluator_params = evaluator_params
self.cross_validator_params = cross_validator_params
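To illustrate the superset requirement enforced above, a hedged construction sketch (the dataset names and InputOutputData instances are invented for the example):

# Sketch: explicit test data for "dataset_a", regular io_data splitting for "dataset_b".
# The keys of test_io_data_dict must be a superset of io_data_dict's keys; values may be None.
io_data_dict = {"dataset_a": io_data_a, "dataset_b": io_data_b}        # placeholder data sets
test_io_data_dict = {"dataset_a": test_io_data_a, "dataset_b": None}   # None => split io_data_b as before
multi_eval = MultiDataModelEvaluation(io_data_dict, key_name="dataset",
    test_io_data_dict=test_io_data_dict)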
Expand All @@ -612,25 +630,34 @@ def compare_models(self,
"""
:param model_factories: a sequence of factory functions for the creation of models to evaluate; every factory must result
in a model with a fixed model name (otherwise results cannot be correctly aggregated)
:param use_cross_validation: whether to use cross-validation (rather than a single split) for model evaluation
:param use_cross_validation: whether to use cross-validation (rather than a single split) for model evaluation.
This can only be used if the instance's ``test_io_data_dict`` is None.
:param result_writer: a writer with which to store results; if None, results are not stored
:param write_per_dataset_results: whether to use resultWriter (if not None) in order to generate detailed results for each
dataset in a subdirectory named according to the name of the dataset
:param write_csvs: whether to write metrics table to CSV files
:param column_name_for_model_ranking: column name to use for ranking models
:param rank_max: if true, use max for ranking, else min
:param add_combined_eval_stats: whether to also report, for each model, evaluation metrics on the combined set of data points from
all EvalStats objects.
Note that for classification, this is only possible if all individual experiments use the same set of class labels.
:param create_metric_distribution_plots: whether to create, for each model, plots of the distribution of each metric across the
datasets (applies only if resultWriter is not None)
datasets (applies only if result_writer is not None)
:param create_combined_eval_stats_plots: whether to combine, for each type of model, the EvalStats objects from the individual
experiments into a single object that holds all results and use it to create plots reflecting the overall result (applies only
if resultWriter is not None).
Note that for classification, this is only possible if all individual experiments use the same set of class labels.
:param distribution_plots_cdf: whether to create CDF plots for the metric distributions. Applies only if
create_metric_distribution_plots is True and result_writer is not None.
:param distribution_plots_cdf_complementary: whether to plot the complementary cdf instead of the regular cdf, provided that
distribution_plots_cdf is True.
:param visitors: visitors which may process individual results. Plots generated by visitors are created/collected at the end of the
comparison.
:return: an object containing the full comparison results
"""
if self.test_io_data_dict and use_cross_validation:
raise ValueError("Cannot use cross-validation when `test_io_data_dict` is specified")

all_results_df = pd.DataFrame()
eval_stats_by_model_name = defaultdict(list)
results_by_model_name: Dict[str, List[ModelComparisonData.Result]] = defaultdict(list)
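To make the cross-validation restriction concrete, a hedged call sketch; the model factory functions are placeholders, and only parameters documented above are used:

# Sketch: each factory must return a model with a fixed model name so results can be aggregated.
# use_cross_validation must remain False here because test_io_data_dict was provided above.
comparison = multi_eval.compare_models(
    model_factories=[make_model_a, make_model_b],   # placeholder factories
    use_cross_validation=False,
    add_combined_eval_stats=True)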
@@ -659,8 +686,9 @@ else:
else:
raise ValueError("The models have to be either all regression models or all classification, not a mixture")

test_io_data = self.test_io_data_dict[key] if self.test_io_data_dict is not None else None
ev = create_evaluation_util(inputOutputData, is_regression=is_regression, evaluator_params=self.evaluator_params,
cross_validator_params=self.cross_validator_params)
cross_validator_params=self.cross_validator_params, test_io_data=test_io_data)

if plot_collector is None:
plot_collector = ev.eval_stats_plot_collector
@@ -918,7 +946,7 @@ def create_distribution_plots(self, result_writer: ResultWriter, cdf=True, cdf_c

:param result_writer: the result writer
:param cdf: whether to additionally plot, for each distribution, the cumulative distribution function
:param cdf_complementary: whether to plot the complementary cdf, provided that ``cdf`` is True
:param cdf_complementary: whether to plot the complementary cdf instead of the regular cdf, provided that ``cdf`` is True
"""
for modelName in self.get_model_names():
eval_stats_collection = self.get_eval_stats_collection(modelName)