From 9559849c9deb2d77f6f05e80c215d3cfbd9b30d5 Mon Sep 17 00:00:00 2001
From: Ravin Kohli <13005107+ravinkohli@users.noreply.github.com>
Date: Fri, 12 Nov 2021 21:03:42 +0100
Subject: [PATCH] [FIX] Additional metrics during training (#316)

* additional_metrics during training

* fix flake

* Add test for unsupported budget type

* Apply suggestions from code review

Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>

Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
---
 autoPyTorch/evaluation/abstract_evaluator.py  | 80 ++++++++++++++-----
 .../test_abstract_evaluator.py                | 32 ++++++++
 test/test_evaluation/test_train_evaluator.py  | 32 ++++++++
 3 files changed, 123 insertions(+), 21 deletions(-)

diff --git a/autoPyTorch/evaluation/abstract_evaluator.py b/autoPyTorch/evaluation/abstract_evaluator.py
index d0363a248..b926a50a7 100644
--- a/autoPyTorch/evaluation/abstract_evaluator.py
+++ b/autoPyTorch/evaluation/abstract_evaluator.py
@@ -490,37 +490,23 @@ def __init__(self, backend: Backend,
             ))
 
         self.additional_metrics: Optional[List[autoPyTorchMetric]] = None
+        metrics_dict: Optional[Dict[str, List[str]]] = None
         if all_supported_metrics:
             self.additional_metrics = get_metrics(dataset_properties=self.dataset_properties,
                                                   all_supported_metrics=all_supported_metrics)
+            # Update fit dictionary with metrics passed to the evaluator
+            metrics_dict = {'additional_metrics': []}
+            metrics_dict['additional_metrics'].append(self.metric.name)
+            for metric in self.additional_metrics:
+                metrics_dict['additional_metrics'].append(metric.name)
 
-        self.fit_dictionary: Dict[str, Any] = {'dataset_properties': self.dataset_properties}
         self._init_params = init_params
-        self.fit_dictionary.update({
-            'X_train': self.X_train,
-            'y_train': self.y_train,
-            'X_test': self.X_test,
-            'y_test': self.y_test,
-            'backend': self.backend,
-            'logger_port': logger_port,
-            'optimize_metric': self.metric.name
-        })
+
         assert self.pipeline_class is not None, "Could not infer pipeline class"
         pipeline_config = pipeline_config if pipeline_config is not None \
             else self.pipeline_class.get_default_pipeline_options()
         self.budget_type = pipeline_config['budget_type'] if budget_type is None else budget_type
         self.budget = pipeline_config[self.budget_type] if budget == 0 else budget
-        self.fit_dictionary = {**pipeline_config, **self.fit_dictionary}
-
-        # If the budget is epochs, we want to limit that in the fit dictionary
-        if self.budget_type == 'epochs':
-            self.fit_dictionary['epochs'] = budget
-            self.fit_dictionary.pop('runtime', None)
-        elif self.budget_type == 'runtime':
-            self.fit_dictionary['runtime'] = budget
-            self.fit_dictionary.pop('epochs', None)
-        else:
-            raise ValueError(f"Unsupported budget type {self.budget_type} provided")
 
         self.num_run = 0 if num_run is None else num_run
 
@@ -533,6 +519,7 @@ def __init__(self, backend: Backend,
             port=logger_port,
         )
 
+        self._init_fit_dictionary(logger_port=logger_port, pipeline_config=pipeline_config, metrics_dict=metrics_dict)
         self.Y_optimization: Optional[np.ndarray] = None
         self.Y_actual_train: Optional[np.ndarray] = None
         self.pipelines: Optional[List[BaseEstimator]] = None
@@ -540,6 +527,57 @@ def __init__(self, backend: Backend,
         self.logger.debug("Fit dictionary in Abstract evaluator: {}".format(dict_repr(self.fit_dictionary)))
         self.logger.debug("Search space updates :{}".format(self.search_space_updates))
 
+    def _init_fit_dictionary(
+        self,
+        logger_port: int,
+        pipeline_config: Dict[str, Any],
+        metrics_dict: Optional[Dict[str, List[str]]] = None,
+    ) -> None:
+        """
+        Initialises the fit dictionary
+
+        Args:
+            logger_port (int):
+                Logging is performed using a socket-server scheme to be robust against many
+                parallel entities that want to write to the same file. This integer states the
+                socket port for the communication channel.
+            pipeline_config (Dict[str, Any]):
+                Defines the content of the pipeline being evaluated. For example, it
+                contains pipeline specific settings like logging name, or whether or not
+                to use tensorboard.
+            metrics_dict (Optional[Dict[str, List[str]]]):
+                Contains a list of metric names to be evaluated by the Trainer with the key `additional_metrics`. Defaults to None.
+
+        Returns:
+            None
+        """
+
+        self.fit_dictionary: Dict[str, Any] = {'dataset_properties': self.dataset_properties}
+
+        if metrics_dict is not None:
+            self.fit_dictionary.update(metrics_dict)
+
+        self.fit_dictionary.update({
+            'X_train': self.X_train,
+            'y_train': self.y_train,
+            'X_test': self.X_test,
+            'y_test': self.y_test,
+            'backend': self.backend,
+            'logger_port': logger_port,
+            'optimize_metric': self.metric.name
+        })
+
+        self.fit_dictionary.update(pipeline_config)
+        # If the budget is epochs, we want to limit that in the fit dictionary
+        if self.budget_type == 'epochs':
+            self.fit_dictionary['epochs'] = self.budget
+            self.fit_dictionary.pop('runtime', None)
+        elif self.budget_type == 'runtime':
+            self.fit_dictionary['runtime'] = self.budget
+            self.fit_dictionary.pop('epochs', None)
+        else:
+            raise ValueError(f"budget type must be `epochs` or `runtime`, but got {self.budget_type}")
+
     def _get_pipeline(self) -> BaseEstimator:
         """
         Implements a pipeline object based on the self.configuration attribute.
diff --git a/test/test_evaluation/test_abstract_evaluator.py b/test/test_evaluation/test_abstract_evaluator.py
index 999c2dd7e..6cec57fb4 100644
--- a/test/test_evaluation/test_abstract_evaluator.py
+++ b/test/test_evaluation/test_abstract_evaluator.py
@@ -282,3 +282,35 @@ def test_file_output(self):
                                                     '.autoPyTorch', 'runs', '1_0_1.0')))
 
         shutil.rmtree(self.working_directory, ignore_errors=True)
+
+    def test_error_unsupported_budget_type(self):
+        shutil.rmtree(self.working_directory, ignore_errors=True)
+        os.mkdir(self.working_directory)
+
+        queue_mock = unittest.mock.Mock()
+
+        context = BackendContext(
+            prefix='autoPyTorch',
+            temporary_directory=os.path.join(self.working_directory, 'tmp'),
+            output_directory=os.path.join(self.working_directory, 'out'),
+            delete_tmp_folder_after_terminate=True,
+            delete_output_folder_after_terminate=True,
+        )
+        with unittest.mock.patch.object(Backend, 'load_datamanager') as load_datamanager_mock:
+            load_datamanager_mock.return_value = get_multiclass_classification_datamanager()
+
+            backend = Backend(context, prefix='autoPyTorch')
+
+            try:
+                AbstractEvaluator(
+                    backend=backend,
+                    output_y_hat_optimization=False,
+                    queue=queue_mock,
+                    pipeline_config={'budget_type': "error", 'error': 0},
+                    metric=accuracy,
+                    budget=0,
+                    configuration=1)
+            except Exception as e:
+                self.assertIsInstance(e, ValueError)
+
+        shutil.rmtree(self.working_directory, ignore_errors=True)
diff --git a/test/test_evaluation/test_train_evaluator.py b/test/test_evaluation/test_train_evaluator.py
index 234eaae71..a3ff067f1 100644
--- a/test/test_evaluation/test_train_evaluator.py
+++ b/test/test_evaluation/test_train_evaluator.py
@@ -262,3 +262,35 @@ def test_get_results(self):
         self.assertEqual(len(result), 5)
         self.assertEqual(result[0][0], 0)
         self.assertAlmostEqual(result[0][1], 1.0)
+
+    @unittest.mock.patch('autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline')
+    def test_additional_metrics_during_training(self, pipeline_mock):
+        pipeline_mock.fit_dictionary = {'budget_type': 'epochs', 'epochs': 50}
+        # Binary iris, contains 69 train samples, 31 test samples
+        D = get_binary_classification_datamanager()
+        pipeline_mock.predict_proba.side_effect = \
+            lambda X, batch_size=None: np.tile([0.6, 0.4], (len(X), 1))
+        pipeline_mock.side_effect = lambda **kwargs: pipeline_mock
+        pipeline_mock.get_additional_run_info.return_value = None
+
+        # Binary iris, contains 69 train samples, 31 test samples
+        D = get_binary_classification_datamanager()
+
+        configuration = unittest.mock.Mock(spec=Configuration)
+        backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch')
+        backend_api.load_datamanager = lambda: D
+        queue_ = multiprocessing.Queue()
+
+        evaluator = TrainEvaluator(backend_api, queue_, configuration=configuration, metric=accuracy, budget=0,
+                                   pipeline_config={'budget_type': 'epochs', 'epochs': 50}, all_supported_metrics=True)
+        evaluator.file_output = unittest.mock.Mock(spec=evaluator.file_output)
+        evaluator.file_output.return_value = (None, {})
+
+        evaluator.fit_predict_and_loss()
+
+        rval = read_queue(evaluator.queue)
+        self.assertEqual(len(rval), 1)
+        result = rval[0]
+        self.assertIn('additional_run_info', result)
+        self.assertIn('opt_loss', result['additional_run_info'])
+        self.assertGreater(len(result['additional_run_info']['opt_loss'].keys()), 1)
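
A minimal sketch of the metric bookkeeping this patch introduces, for readers who want to see the resulting fit-dictionary entry in isolation. Only the 'additional_metrics' key and the ordering (optimization metric first, then the additional metrics) come from the diff above; the helper name build_metrics_dict and the example metric names are assumptions made purely for illustration and are not part of the patch.

# Sketch (not part of the patch) of how the patched __init__ assembles
# `metrics_dict` before passing it to `_init_fit_dictionary`: the optimization
# metric name comes first, followed by the names of all additional metrics.
from typing import Dict, List, Optional


def build_metrics_dict(optimize_metric: str,
                       additional_metrics: Optional[List[str]]) -> Optional[Dict[str, List[str]]]:
    # When all_supported_metrics is falsy, the evaluator passes metrics_dict=None
    # and the Trainer only receives the optimization metric.
    if additional_metrics is None:
        return None
    return {'additional_metrics': [optimize_metric, *additional_metrics]}


# Example: the fit dictionary would gain an entry such as
# {'additional_metrics': ['accuracy', 'balanced_accuracy', 'roc_auc']}
print(build_metrics_dict('accuracy', ['balanced_accuracy', 'roc_auc']))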