diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index 9e59601a8..940c126ce 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -202,7 +202,6 @@ def __init__( self._multiprocessing_context = 'forkserver' if self.n_jobs == 1: self._multiprocessing_context = 'fork' - self._dask_client = SingleThreadedClient() self.InputValidator: Optional[BaseInputValidator] = None @@ -698,8 +697,9 @@ def _search( self, optimize_metric: str, dataset: BaseDataset, - budget_type: Optional[str] = None, - budget: Optional[float] = None, + budget_type: str = 'epochs', + min_budget: int = 5, + max_budget: int = 50, total_walltime_limit: int = 100, func_eval_time_limit_secs: Optional[int] = None, enable_traditional_pipeline: bool = True, @@ -728,13 +728,36 @@ def _search( Providing X_train, y_train and dataset together is not supported. optimize_metric (str): name of the metric that is used to evaluate a pipeline. - budget_type (Optional[str]): + budget_type (str): Type of budget to be used when fitting the pipeline. - Either 'epochs' or 'runtime'. If not provided, uses - the default in the pipeline config ('epochs') - budget (Optional[float]): - Budget to fit a single run of the pipeline. If not - provided, uses the default in the pipeline config + It can be one of: + + 'epochs': The training of each pipeline will be terminated after + a number of epochs have passed. This number of epochs is determined by the + budget argument of this method. + + 'runtime': The training of each pipeline will be terminated after + a number of seconds have passed. This number of seconds is determined by the + budget argument of this method. The overall fitting time of a pipeline is + controlled by func_eval_time_limit_secs. 'runtime' only controls the allocated + time to train a pipeline, but it does not consider the overall time it takes + to create a pipeline (data loading and preprocessing, other i/o operations, etc.). 
+ budget_type will determine the units of min_budget/max_budget. If budget_type=='epochs' + is used, min_budget will refer to epochs whereas if budget_type=='runtime' then + min_budget will refer to seconds. + min_budget (int): + Auto-PyTorch uses `Hyperband <https://arxiv.org/abs/1603.06560>`_ to + trade-off resources between running many pipelines at min_budget and + running the top performing pipelines on max_budget. + min_budget states the minimum resource allocation a pipeline should have + so that we can compare and quickly discard poorly performing models. + For example, if the budget_type is epochs, and min_budget=5, then we will + run every pipeline to a minimum of 5 epochs before performance comparison. + max_budget (int): + Auto-PyTorch uses `Hyperband <https://arxiv.org/abs/1603.06560>`_ to + trade-off resources between running many pipelines at min_budget and + running the top performing pipelines on max_budget. + max_budget states the maximum resource allocation a pipeline is going to + be run. For example, if the budget_type is epochs, and max_budget=50, + then the pipeline training will be terminated after 50 epochs. total_walltime_limit (int), (default=100): Time limit in seconds for the search of appropriate models. 
By increasing this value, autopytorch has a higher @@ -843,23 +866,27 @@ def _search( self.search_space = self.get_search_space(dataset) - budget_config: Dict[str, Union[float, str]] = {} - if budget_type is not None and budget is not None: - budget_config['budget_type'] = budget_type - budget_config[budget_type] = budget - elif budget_type is not None or budget is not None: - raise ValueError( - "budget type was not specified in budget_config" - ) + # Incorporate budget to pipeline config + if budget_type not in ('epochs', 'runtime'): + raise ValueError("Budget type must be one ('epochs', 'runtime')" + f" yet {budget_type} was provided") + self.pipeline_options['budget_type'] = budget_type + + # Here the budget is set to max because the SMAC intensifier can be: + # Hyperband: in this case the budget is determined on the fly and overwritten + # by the ExecuteTaFuncWithQueue + # SimpleIntensifier (and others): in this case, we use max_budget as a target + # budget, and hence the below line is honored + self.pipeline_options[budget_type] = max_budget if self.task_type is None: raise ValueError("Cannot interpret task type from the dataset") # If no dask client was provided, we create one, so that we can # start a ensemble process in parallel to smbo optimize - if ( - dask_client is None and (self.ensemble_size > 0 or self.n_jobs > 1) - ): + if self.n_jobs == 1: + self._dask_client = SingleThreadedClient() + elif dask_client is None: self._create_dask_client() else: self._dask_client = dask_client @@ -878,7 +905,7 @@ def _search( # Make sure that at least 2 models are created for the ensemble process num_models = time_left_for_modelfit // func_eval_time_limit_secs - if num_models < 2: + if num_models < 2 and self.ensemble_size > 0: func_eval_time_limit_secs = time_left_for_modelfit // 2 self._logger.warning( "Capping the func_eval_time_limit_secs to {} to have " @@ -978,7 +1005,9 @@ def _search( all_supported_metrics=self._all_supported_metrics, 
smac_scenario_args=smac_scenario_args, get_smac_object_callback=get_smac_object_callback, - pipeline_config={**self.pipeline_options, **budget_config}, + pipeline_config=self.pipeline_options, + min_budget=min_budget, + max_budget=max_budget, ensemble_callback=proc_ensemble, logger_port=self._logger_port, # We do not increase the num_run here, this is something @@ -1046,7 +1075,6 @@ def _search( def refit( self, dataset: BaseDataset, - budget_config: Dict[str, Union[int, str]] = {}, split_id: int = 0 ) -> "BaseTask": """ @@ -1058,14 +1086,16 @@ def refit( This methods fits all models found during a call to fit on the data given. This method may also be used together with holdout to avoid only using 66% of the training data to fit the final model. + + Refit uses the estimator pipeline_config attribute, which the user + can interact via the get_pipeline_config()/set_pipeline_config() + methods. + Args: dataset: (Dataset) The argument that will provide the dataset splits. It can either be a dictionary with the splits, or the dataset object which can generate the splits based on different restrictions. - budget_config: (Optional[Dict[str, Union[int, str]]]) - can contain keys from 'budget_type' and the budget - specified using 'epochs' or 'runtime'. split_id: (int) split id to fit on. Returns: @@ -1096,7 +1126,7 @@ def refit( 'split_id': split_id, 'num_run': self._backend.get_next_num_run(), }) - X.update({**self.pipeline_options, **budget_config}) + X.update(self.pipeline_options) if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None: self._load_models() @@ -1120,21 +1150,22 @@ def refit( def fit(self, dataset: BaseDataset, - budget_config: Dict[str, Union[int, str]] = {}, pipeline_config: Optional[Configuration] = None, split_id: int = 0) -> BasePipeline: """ Fit a pipeline on the given task for the budget. 
A pipeline configuration can be specified if None, uses default + + Fit uses the estimator pipeline_config attribute, which the user + can interact via the get_pipeline_config()/set_pipeline_config() + methods. + Args: dataset: (Dataset) The argument that will provide the dataset splits. It can either be a dictionary with the splits, or the dataset object which can generate the splits based on different restrictions. - budget_config: (Optional[Dict[str, Union[int, str]]]) - can contain keys from 'budget_type' and the budget - specified using 'epochs' or 'runtime'. split_id: (int) (default=0) split id to fit on. pipeline_config: (Optional[Configuration]) @@ -1175,7 +1206,7 @@ def fit(self, 'split_id': split_id, 'num_run': self._backend.get_next_num_run(), }) - X.update({**self.pipeline_options, **budget_config}) + X.update(self.pipeline_options) fit_and_suppress_warnings(self._logger, pipeline, X, y=None) diff --git a/autoPyTorch/api/tabular_classification.py b/autoPyTorch/api/tabular_classification.py index 9da96ef94..20be4346d 100644 --- a/autoPyTorch/api/tabular_classification.py +++ b/autoPyTorch/api/tabular_classification.py @@ -110,8 +110,9 @@ def search( X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, dataset_name: Optional[str] = None, - budget_type: Optional[str] = None, - budget: Optional[float] = None, + budget_type: str = 'epochs', + min_budget: int = 5, + max_budget: int = 50, total_walltime_limit: int = 100, func_eval_time_limit_secs: Optional[int] = None, enable_traditional_pipeline: bool = True, @@ -137,15 +138,38 @@ def search( be provided to track the generalization performance of each stage. optimize_metric (str): name of the metric that is used to evaluate a pipeline. - budget_type (Optional[str]): + budget_type (str): Type of budget to be used when fitting the pipeline. - Either 'epochs' or 'runtime'. 
If not provided, uses - the default in the pipeline config ('epochs') - budget (Optional[float]): - Budget to fit a single run of the pipeline. If not - provided, uses the default in the pipeline config - total_walltime_limit (int), (default=100): - Time limit in seconds for the search of appropriate models. + It can be one of: + + 'epochs': The training of each pipeline will be terminated after + a number of epochs have passed. This number of epochs is determined by the + min_budget/max_budget arguments of this method. + + 'runtime': The training of each pipeline will be terminated after + a number of seconds have passed. This number of seconds is determined by the + min_budget/max_budget arguments of this method. The overall fitting time of a pipeline is + controlled by func_eval_time_limit_secs. 'runtime' only controls the allocated + time to train a pipeline, but it does not consider the overall time it takes + to create a pipeline (data loading and preprocessing, other i/o operations, etc.). + budget_type will determine the units of min_budget/max_budget. If budget_type=='epochs' + is used, min_budget will refer to epochs whereas if budget_type=='runtime' then + min_budget will refer to seconds. + min_budget (int): + Auto-PyTorch uses `Hyperband <https://arxiv.org/abs/1603.06560>`_ to + trade-off resources between running many pipelines at min_budget and + running the top performing pipelines on max_budget. + min_budget states the minimum resource allocation a pipeline should have + so that we can compare and quickly discard poorly performing models. + For example, if the budget_type is epochs, and min_budget=5, then we will + run every pipeline to a minimum of 5 epochs before performance comparison. + max_budget (int): + Auto-PyTorch uses `Hyperband <https://arxiv.org/abs/1603.06560>`_ to + trade-off resources between running many pipelines at min_budget and + running the top performing pipelines on max_budget. + max_budget states the maximum resource allocation a pipeline is going to + be run. 
For example, if the budget_type is epochs, and max_budget=50, + then the pipeline training will be terminated after 50 epochs. + total_walltime_limit (int), (default=100): Time limit + in seconds for the search of appropriate models. By increasing this value, autopytorch has a higher chance of finding better models. func_eval_time_limit_secs (int), (default=None): @@ -234,7 +258,8 @@ def search( dataset=self.dataset, optimize_metric=optimize_metric, budget_type=budget_type, - budget=budget, + min_budget=min_budget, + max_budget=max_budget, total_walltime_limit=total_walltime_limit, func_eval_time_limit_secs=func_eval_time_limit_secs, enable_traditional_pipeline=enable_traditional_pipeline, diff --git a/autoPyTorch/api/tabular_regression.py b/autoPyTorch/api/tabular_regression.py index 599856ce8..b88bf7cd9 100644 --- a/autoPyTorch/api/tabular_regression.py +++ b/autoPyTorch/api/tabular_regression.py @@ -53,23 +53,23 @@ class TabularRegressionTask(BaseTask): """ def __init__( - self, - seed: int = 1, - n_jobs: int = 1, - logging_config: Optional[Dict] = None, - ensemble_size: int = 50, - ensemble_nbest: int = 50, - max_models_on_disc: int = 50, - temporary_directory: Optional[str] = None, - output_directory: Optional[str] = None, - delete_tmp_folder_after_terminate: bool = True, - delete_output_folder_after_terminate: bool = True, - include_components: Optional[Dict] = None, - exclude_components: Optional[Dict] = None, - resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, - resampling_strategy_args: Optional[Dict[str, Any]] = None, - backend: Optional[Backend] = None, - search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None + self, + seed: int = 1, + n_jobs: int = 1, + logging_config: Optional[Dict] = None, + ensemble_size: int = 50, + ensemble_nbest: int = 50, + max_models_on_disc: int = 50, + temporary_directory: Optional[str] = None, + output_directory: Optional[str] = None, + 
delete_tmp_folder_after_terminate: bool = True, + delete_output_folder_after_terminate: bool = True, + include_components: Optional[Dict] = None, + exclude_components: Optional[Dict] = None, + resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, + resampling_strategy_args: Optional[Dict[str, Any]] = None, + backend: Optional[Backend] = None, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None ): super().__init__( seed=seed, @@ -102,8 +102,9 @@ def search( X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, dataset_name: Optional[str] = None, - budget_type: Optional[str] = None, - budget: Optional[float] = None, + budget_type: str = 'epochs', + min_budget: int = 5, + max_budget: int = 50, total_walltime_limit: int = 100, func_eval_time_limit_secs: Optional[int] = None, enable_traditional_pipeline: bool = True, @@ -129,13 +130,36 @@ def search( be provided to track the generalization performance of each stage. optimize_metric (str): name of the metric that is used to evaluate a pipeline. - budget_type (Optional[str]): + budget_type (str): Type of budget to be used when fitting the pipeline. - Either 'epochs' or 'runtime'. If not provided, uses - the default in the pipeline config ('epochs') - budget (Optional[float]): - Budget to fit a single run of the pipeline. If not - provided, uses the default in the pipeline config + It can be one of: + + 'epochs': The training of each pipeline will be terminated after + a number of epochs have passed. This number of epochs is determined by the + budget argument of this method. + + 'runtime': The training of each pipeline will be terminated after + a number of seconds have passed. This number of seconds is determined by the + budget argument of this method. The overall fitting time of a pipeline is + controlled by func_eval_time_limit_secs. 
'runtime' only controls the allocated + time to train a pipeline, but it does not consider the overall time it takes + to create a pipeline (data loading and preprocessing, other i/o operations, etc.). + budget_type will determine the units of min_budget/max_budget. If budget_type=='epochs' + is used, min_budget will refer to epochs whereas if budget_type=='runtime' then + min_budget will refer to seconds. + min_budget (int): + Auto-PyTorch uses `Hyperband <https://arxiv.org/abs/1603.06560>`_ to + trade-off resources between running many pipelines at min_budget and + running the top performing pipelines on max_budget. + min_budget states the minimum resource allocation a pipeline should have + so that we can compare and quickly discard poorly performing models. + For example, if the budget_type is epochs, and min_budget=5, then we will + run every pipeline to a minimum of 5 epochs before performance comparison. + max_budget (int): + Auto-PyTorch uses `Hyperband <https://arxiv.org/abs/1603.06560>`_ to + trade-off resources between running many pipelines at min_budget and + running the top performing pipelines on max_budget. + max_budget states the maximum resource allocation a pipeline is going to + be run. For example, if the budget_type is epochs, and max_budget=50, + then the pipeline training will be terminated after 50 epochs. total_walltime_limit (int), (default=100): Time limit in seconds for the search of appropriate models. 
By increasing this value, autopytorch has a higher @@ -227,7 +251,8 @@ def search( dataset=self.dataset, optimize_metric=optimize_metric, budget_type=budget_type, - budget=budget, + min_budget=min_budget, + max_budget=max_budget, total_walltime_limit=total_walltime_limit, func_eval_time_limit_secs=func_eval_time_limit_secs, enable_traditional_pipeline=enable_traditional_pipeline, diff --git a/autoPyTorch/configs/default_pipeline_options.json b/autoPyTorch/configs/default_pipeline_options.json index c5481080c..9c73fc181 100644 --- a/autoPyTorch/configs/default_pipeline_options.json +++ b/autoPyTorch/configs/default_pipeline_options.json @@ -1,11 +1,10 @@ { - "device": "cpu", - "budget_type": "epochs", - "min_epochs": 5, - "epochs": 50, - "runtime": 3600, - "torch_num_threads": 1, - "early_stopping": 20, - "use_tensorboard_logger": "False", - "metrics_during_training": "True" + "device": "cpu", + "budget_type": "epochs", + "epochs": 50, + "runtime": 3600, + "torch_num_threads": 1, + "early_stopping": 20, + "use_tensorboard_logger": "False", + "metrics_during_training": "True" } diff --git a/autoPyTorch/evaluation/abstract_evaluator.py b/autoPyTorch/evaluation/abstract_evaluator.py index c6e28cfaa..0f998d73c 100644 --- a/autoPyTorch/evaluation/abstract_evaluator.py +++ b/autoPyTorch/evaluation/abstract_evaluator.py @@ -515,6 +515,12 @@ def __init__(self, backend: Backend, # If the budget is epochs, we want to limit that in the fit dictionary if self.budget_type == 'epochs': self.fit_dictionary['epochs'] = budget + self.fit_dictionary.pop('runtime', None) + elif self.budget_type == 'runtime': + self.fit_dictionary['runtime'] = budget + self.fit_dictionary.pop('epochs', None) + else: + raise ValueError(f"Unsupported budget type {self.budget_type} provided") self.num_run = 0 if num_run is None else num_run diff --git a/autoPyTorch/evaluation/tae.py b/autoPyTorch/evaluation/tae.py index 969a7a785..96eae3351 100644 --- a/autoPyTorch/evaluation/tae.py +++ 
b/autoPyTorch/evaluation/tae.py @@ -209,9 +209,14 @@ def run_wrapper( ) else: if run_info.budget == 0: - run_info = run_info._replace(budget=self.pipeline_config[self.budget_type]) - elif run_info.budget <= 0 or run_info.budget > 100: - raise ValueError('Illegal value for budget, must be >0 and <=100, but is %f' % + # SMAC can return budget zero for intensifiers that don't have a concept + # of budget, for example a simple bayesian optimization intensifier. + # Budget determines how our pipeline trains, which can be via runtime or epochs + epochs_budget = self.pipeline_config.get('epochs', np.inf) + runtime_budget = self.pipeline_config.get('runtime', np.inf) + run_info = run_info._replace(budget=min(epochs_budget, runtime_budget)) + elif run_info.budget <= 0: + raise ValueError('Illegal value for budget, must be greater than zero but is %f' % run_info.budget) if self.budget_type not in ('epochs', 'runtime'): raise ValueError("Illegal value for budget type, must be one of " diff --git a/autoPyTorch/optimizer/smbo.py b/autoPyTorch/optimizer/smbo.py index 094c92c25..4c16ca4ce 100644 --- a/autoPyTorch/optimizer/smbo.py +++ b/autoPyTorch/optimizer/smbo.py @@ -111,6 +111,8 @@ def __init__(self, search_space_updates: typing.Optional[HyperparameterSearchSpaceUpdates] = None, portfolio_selection: typing.Optional[str] = None, pynisher_context: str = 'spawn', + min_budget: int = 5, + max_budget: int = 50, ): """ Interface to SMAC. This method calls the SMAC optimize method, and allows @@ -169,7 +171,22 @@ def __init__(self, configurations, similar to (autoPyTorch/configs/greedy_portfolio.json). Additionally, the keyword 'greedy' is supported, which would use the default portfolio from - `AutoPyTorch Tabular ` + `AutoPyTorch Tabular _` + min_budget (int): + Auto-PyTorch uses `Hyperband _` to + trade-off resources between running many pipelines at min_budget and + running the top performing pipelines on max_budget. 
+ min_budget states the minimum resource allocation a pipeline should have + so that we can compare and quickly discard bad performing models. + For example, if the budget_type is epochs, and min_budget=5, then we will + run every pipeline to a minimum of 5 epochs before performance comparison. + max_budget (int): + Auto-PyTorch uses `Hyperband _` to + trade-off resources between running many pipelines at min_budget and + running the top performing pipelines on max_budget. + max_budget states the maximum resource allocation a pipeline is going to + be ran. For example, if the budget_type is epochs, and max_budget=50, + then the pipeline training will be terminated after 50 epochs. """ super(AutoMLSMBO, self).__init__() # data related @@ -208,6 +225,8 @@ def __init__(self, self.smac_scenario_args = smac_scenario_args self.get_smac_object_callback = get_smac_object_callback self.pynisher_context = pynisher_context + self.min_budget = min_budget + self.max_budget = max_budget self.ensemble_callback = ensemble_callback @@ -326,17 +345,14 @@ def run_smbo(self, func: typing.Optional[typing.Callable] = None ) scenario_dict.update(self.smac_scenario_args) - initial_budget = self.pipeline_config['min_epochs'] - max_budget = self.pipeline_config['epochs'] - if self.get_smac_object_callback is not None: smac = self.get_smac_object_callback(scenario_dict=scenario_dict, seed=seed, ta=ta, ta_kwargs=ta_kwargs, n_jobs=self.n_jobs, - initial_budget=initial_budget, - max_budget=max_budget, + initial_budget=self.min_budget, + max_budget=self.max_budget, dask_client=self.dask_client, initial_configurations=self.initial_configurations) else: @@ -345,8 +361,8 @@ def run_smbo(self, func: typing.Optional[typing.Callable] = None ta=ta, ta_kwargs=ta_kwargs, n_jobs=self.n_jobs, - initial_budget=initial_budget, - max_budget=max_budget, + initial_budget=self.min_budget, + max_budget=self.max_budget, dask_client=self.dask_client, initial_configurations=self.initial_configurations) diff --git 
a/autoPyTorch/pipeline/components/training/trainer/__init__.py b/autoPyTorch/pipeline/components/training/trainer/__init__.py index c490a405c..1eb24443a 100755 --- a/autoPyTorch/pipeline/components/training/trainer/__init__.py +++ b/autoPyTorch/pipeline/components/training/trainer/__init__.py @@ -242,7 +242,7 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic if X["torch_num_threads"] > 0: torch.set_num_threads(X["torch_num_threads"]) - budget_tracker = BudgetTracker( + self.budget_tracker = BudgetTracker( budget_type=X['budget_type'], max_runtime=X['runtime'] if 'runtime' in X else None, max_epochs=X['epochs'] if 'epochs' in X else None, @@ -260,7 +260,7 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic metrics=metrics, criterion=get_loss(X['dataset_properties'], name=additional_losses), - budget_tracker=budget_tracker, + budget_tracker=self.budget_tracker, optimizer=X['optimizer'], device=get_device_from_fit_dictionary(X), metrics_during_training=X['metrics_during_training'], @@ -322,7 +322,10 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic self.logger.debug(self.run_summary.repr_last_epoch()) # Reached max epoch on next iter, don't even go there - if budget_tracker.is_max_epoch_reached(epoch + 1): + if ( + self.budget_tracker.is_max_epoch_reached(epoch + 1) + or self.budget_tracker.is_max_time_reached() + ): break epoch += 1 diff --git a/test/test_api/test_base_api.py b/test/test_api/test_base_api.py index 8949f9f28..6c74515b1 100644 --- a/test/test_api/test_base_api.py +++ b/test/test_api/test_base_api.py @@ -7,6 +7,9 @@ import pytest +from smac.runhistory.runhistory import RunHistory +from smac.tae.serial_runner import SerialRunner + from autoPyTorch.api.base_task import BaseTask, _pipeline_predict from autoPyTorch.constants import TABULAR_CLASSIFICATION, TABULAR_REGRESSION from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline 
@@ -94,8 +97,45 @@ def test_set_pipeline_config(): estimator = BaseTask() pipeline_options = {"device": "cuda", "budget_type": "epochs", - "min_epochs": 10, "epochs": 51, "runtime": 360} estimator.set_pipeline_config(**pipeline_options) assert pipeline_options.items() <= estimator.get_pipeline_options().items() + + +@pytest.mark.parametrize("fit_dictionary_tabular", ['classification_categorical_only'], indirect=True) +@pytest.mark.parametrize( + "min_budget,max_budget,budget_type,expected", [ + (5, 75, 'epochs', {'budget_type': 'epochs', 'epochs': 75}), + (3, 50, 'runtime', {'budget_type': 'runtime', 'runtime': 50}), + ]) +def test_pipeline_get_budget(fit_dictionary_tabular, min_budget, max_budget, budget_type, expected): + estimator = BaseTask(task_type='tabular_classification', ensemble_size=0) + + # Fixture pipeline config + default_pipeline_config = { + 'device': 'cpu', 'budget_type': 'epochs', 'epochs': 50, 'runtime': 3600, + 'torch_num_threads': 1, 'early_stopping': 20, 'use_tensorboard_logger': False, + 'metrics_during_training': True, 'optimize_metric': 'accuracy' + } + default_pipeline_config.update(expected) + + # Create pre-requisites + dataset = fit_dictionary_tabular['backend'].load_datamanager() + pipeline_fit = unittest.mock.Mock() + + smac = unittest.mock.Mock() + smac.solver.runhistory = RunHistory() + smac.solver.intensifier.traj_logger.trajectory = [] + smac.solver.tae_runner = unittest.mock.Mock(spec=SerialRunner) + smac.solver.tae_runner.budget_type = 'epochs' + with unittest.mock.patch('autoPyTorch.optimizer.smbo.get_smac_object') as smac_mock: + smac_mock.return_value = smac + estimator._search(optimize_metric='accuracy', dataset=dataset, tae_func=pipeline_fit, + min_budget=min_budget, max_budget=max_budget, budget_type=budget_type, + enable_traditional_pipeline=False, + total_walltime_limit=10, func_eval_time_limit_secs=5, + load_models=False) + assert list(smac_mock.call_args)[1]['ta_kwargs']['pipeline_config'] == default_pipeline_config + 
assert list(smac_mock.call_args)[1]['max_budget'] == max_budget + assert list(smac_mock.call_args)[1]['initial_budget'] == min_budget diff --git a/test/test_evaluation/test_train_evaluator.py b/test/test_evaluation/test_train_evaluator.py index ae35c097b..234eaae71 100644 --- a/test/test_evaluation/test_train_evaluator.py +++ b/test/test_evaluation/test_train_evaluator.py @@ -87,6 +87,7 @@ def tearDown(self): @unittest.mock.patch('autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline') def test_holdout(self, pipeline_mock): + pipeline_mock.fit_dictionary = {'budget_type': 'epochs', 'epochs': 50} # Binary iris, contains 69 train samples, 31 test samples D = get_binary_classification_datamanager() pipeline_mock.predict_proba.side_effect = \ @@ -99,7 +100,8 @@ def test_holdout(self, pipeline_mock): backend_api.load_datamanager = lambda: D queue_ = multiprocessing.Queue() - evaluator = TrainEvaluator(backend_api, queue_, configuration=configuration, metric=accuracy, budget=0) + evaluator = TrainEvaluator(backend_api, queue_, configuration=configuration, metric=accuracy, budget=0, + pipeline_config={'budget_type': 'epochs', 'epochs': 50}) evaluator.file_output = unittest.mock.Mock(spec=evaluator.file_output) evaluator.file_output.return_value = (None, {}) @@ -137,7 +139,8 @@ def test_cv(self, pipeline_mock): backend_api.load_datamanager = lambda: D queue_ = multiprocessing.Queue() - evaluator = TrainEvaluator(backend_api, queue_, configuration=configuration, metric=accuracy, budget=0) + evaluator = TrainEvaluator(backend_api, queue_, configuration=configuration, metric=accuracy, budget=0, + pipeline_config={'budget_type': 'epochs', 'epochs': 50}) evaluator.file_output = unittest.mock.Mock(spec=evaluator.file_output) evaluator.file_output.return_value = (None, {}) @@ -241,7 +244,8 @@ def test_predict_proba_binary_classification(self, mock): configuration = unittest.mock.Mock(spec=Configuration) queue_ = multiprocessing.Queue() - evaluator = 
TrainEvaluator(self.backend_mock, queue_, configuration=configuration, metric=accuracy, budget=0) + evaluator = TrainEvaluator(self.backend_mock, queue_, configuration=configuration, metric=accuracy, budget=0, + pipeline_config={'budget_type': 'epochs', 'epochs': 50}) evaluator.fit_predict_and_loss() Y_optimization_pred = self.backend_mock.save_numrun_to_dir.call_args_list[0][1][ diff --git a/test/test_pipeline/test_tabular_classification.py b/test/test_pipeline/test_tabular_classification.py index 0184d84f3..3f962c9d1 100644 --- a/test/test_pipeline/test_tabular_classification.py +++ b/test/test_pipeline/test_tabular_classification.py @@ -450,3 +450,35 @@ def test_pipeline_score(fit_dictionary_tabular_dummy): # we should be able to get a decent score on this dummy data assert accuracy >= 0.8, f"Pipeline:{pipeline} Config:{config} FitDict: {fit_dictionary_tabular_dummy}" + + +@pytest.mark.parametrize("fit_dictionary_tabular_dummy", ["classification"], indirect=True) +def test_train_pipeline_with_runtime(fit_dictionary_tabular_dummy): + """This test makes sure that, with budget_type 'runtime' and no epoch limit, + the pipeline training is terminated by the runtime budget""" + + # Switch the training budget from epochs to runtime + fit_dictionary_tabular_dummy.pop('epochs', None) + fit_dictionary_tabular_dummy['budget_type'] = 'runtime' + fit_dictionary_tabular_dummy['runtime'] = 3 + fit_dictionary_tabular_dummy['early_stopping'] = -1 + + pipeline = TabularClassificationPipeline( + dataset_properties=fit_dictionary_tabular_dummy['dataset_properties']) + + cs = pipeline.get_hyperparameter_search_space() + config = cs.get_default_configuration() + pipeline.set_hyperparameters(config) + + pipeline.fit(fit_dictionary_tabular_dummy) + run_summary = pipeline.named_steps['trainer'].run_summary + budget_tracker = pipeline.named_steps['trainer'].budget_tracker + assert budget_tracker.budget_type == 'runtime' + assert budget_tracker.max_runtime == 3 + assert budget_tracker.is_max_time_reached() + + # There 
is no epoch limitation + assert not budget_tracker.is_max_epoch_reached(epoch=np.inf) + + # More than 100 epochs should have passed in 3 seconds for this dataset + assert len(run_summary.performance_tracker['start_time']) > 100