From fae72a4de79b747918076292b0540a8e4602cb78 Mon Sep 17 00:00:00 2001 From: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com> Date: Wed, 28 Apr 2021 20:37:48 +0200 Subject: [PATCH] Refactoring base dataset splitting functions (#106) * [Fork from #105] Made CrossValFuncs and HoldOutFuncs classes to group the functions * Modified time_series_dataset.py to be compatible with resampling_strategy.py * [fix]: back to the renamed version of CROSS_VAL_FN from temporal SplitFunc typing. * fixed flake8 issues in three files * fixed the flake8 issues * [refactor] Address the francisco's comments * [refactor] Address the francisco's comments * [refactor] Address the doc-string issue in TransformSubset class * [fix] Address flake8 issues * [fix] Fix flake8 issue * [fix] Fix mypy issues raised by github check * [fix] Fix a mypy issue * [fix] Fix a contradiction in holdout_stratified_validation Since stratified splitting requires shuffling by default and it raised an error in the github check, I fixed this issue. * [fix] Address the francisco's review * [fix] Fix a mypy issue in tabular_dataset.py * [fix] Address the francisco's comment about the self.dataset_name Since we would like to use a dataset which does not have any name, I decided to get self.dataset_name back to Optional[str]. 
* [fix] Fix mypy issues --- autoPyTorch/api/base_task.py | 22 +- autoPyTorch/datasets/base_dataset.py | 51 +++-- autoPyTorch/datasets/resampling_strategy.py | 228 ++++++++++++-------- autoPyTorch/datasets/time_series_dataset.py | 18 +- autoPyTorch/ensemble/ensemble_builder.py | 101 +++++---- 5 files changed, 240 insertions(+), 180 deletions(-) diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index c4fa0e7ce..3c712efa9 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -10,7 +10,6 @@ import time import typing import unittest.mock -import uuid import warnings from abc import abstractmethod from typing import Any, Callable, Dict, List, Optional, Union, cast @@ -782,13 +781,15 @@ def _search( ":{}".format(self.task_type, dataset.task_type)) # Initialise information needed for the experiment - experiment_task_name = 'runSearch' + experiment_task_name: str = 'runSearch' dataset_requirements = get_dataset_requirements( info=self._get_required_dataset_properties(dataset)) self._dataset_requirements = dataset_requirements dataset_properties = dataset.get_dataset_properties(dataset_requirements) self._stopwatch.start_task(experiment_task_name) self.dataset_name = dataset.dataset_name + assert self.dataset_name is not None + if self._logger is None: self._logger = self._get_logger(self.dataset_name) self._all_supported_metrics = all_supported_metrics @@ -897,7 +898,7 @@ def _search( start_time=time.time(), time_left_for_ensembles=time_left_for_ensembles, backend=copy.deepcopy(self._backend), - dataset_name=dataset.dataset_name, + dataset_name=str(dataset.dataset_name), output_type=STRING_TO_OUTPUT_TYPES[dataset.output_type], task_type=STRING_TO_TASK_TYPES[self.task_type], metrics=[self._metric], @@ -916,7 +917,7 @@ def _search( self._stopwatch.stop_task(ensemble_task_name) # ==> Run SMAC - smac_task_name = 'runSMAC' + smac_task_name: str = 'runSMAC' self._stopwatch.start_task(smac_task_name) elapsed_time = 
self._stopwatch.wall_elapsed(experiment_task_name) time_left_for_smac = max(0, total_walltime_limit - elapsed_time) @@ -928,7 +929,7 @@ def _search( _proc_smac = AutoMLSMBO( config_space=self.search_space, - dataset_name=dataset.dataset_name, + dataset_name=str(dataset.dataset_name), backend=self._backend, total_walltime_limit=total_walltime_limit, func_eval_time_limit_secs=func_eval_time_limit_secs, @@ -1035,11 +1036,11 @@ def refit( Returns: self """ - if self.dataset_name is None: - self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) + + self.dataset_name = dataset.dataset_name if self._logger is None: - self._logger = self._get_logger(self.dataset_name) + self._logger = self._get_logger(str(self.dataset_name)) dataset_requirements = get_dataset_requirements( info=self._get_required_dataset_properties(dataset)) @@ -1105,11 +1106,10 @@ def fit(self, Returns: (BasePipeline): fitted pipeline """ - if self.dataset_name is None: - self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) + self.dataset_name = dataset.dataset_name if self._logger is None: - self._logger = self._get_logger(self.dataset_name) + self._logger = self._get_logger(str(self.dataset_name)) # get dataset properties dataset_requirements = get_dataset_requirements( diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py index 4c19fa17d..2f99e54f7 100644 --- a/autoPyTorch/datasets/base_dataset.py +++ b/autoPyTorch/datasets/base_dataset.py @@ -1,3 +1,5 @@ +import os +import uuid from abc import ABCMeta from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast @@ -13,18 +15,17 @@ from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES from autoPyTorch.datasets.resampling_strategy import ( - CROSS_VAL_FN, + CrossValFunc, + CrossValFuncs, CrossValTypes, DEFAULT_RESAMPLING_PARAMETERS, - HOLDOUT_FN, - HoldoutValTypes, - get_cross_validators, - get_holdout_validators, - is_stratified, + HoldOutFunc, + HoldOutFuncs, + 
HoldoutValTypes ) -from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix +from autoPyTorch.utils.common import FitRequirement -BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset] +BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset] def check_valid_data(data: Any) -> None: @@ -33,7 +34,8 @@ def check_valid_data(data: Any) -> None: 'The specified Data for Dataset must have both __getitem__ and __len__ attribute.') -def type_check(train_tensors: BaseDatasetType, val_tensors: Optional[BaseDatasetType] = None) -> None: +def type_check(train_tensors: BaseDatasetInputType, + val_tensors: Optional[BaseDatasetInputType] = None) -> None: """To avoid unexpected behavior, we use loops over indices.""" for i in range(len(train_tensors)): check_valid_data(train_tensors[i]) @@ -49,8 +51,8 @@ class TransformSubset(Subset): we require different transformation for each data point. This class helps to take the subset of the dataset with either training or validation transformation. - - We achieve so by adding a train flag to the pytorch subset + The TransformSubset allows to add train flags + while indexing the main dataset towards this goal. 
Attributes: dataset (BaseDataset/Dataset): Dataset to sample the subset @@ -71,10 +73,10 @@ def __getitem__(self, idx: int) -> np.ndarray: class BaseDataset(Dataset, metaclass=ABCMeta): def __init__( self, - train_tensors: BaseDatasetType, + train_tensors: BaseDatasetInputType, dataset_name: Optional[str] = None, - val_tensors: Optional[BaseDatasetType] = None, - test_tensors: Optional[BaseDatasetType] = None, + val_tensors: Optional[BaseDatasetInputType] = None, + test_tensors: Optional[BaseDatasetInputType] = None, resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, resampling_strategy_args: Optional[Dict[str, Any]] = None, shuffle: Optional[bool] = True, @@ -106,14 +108,16 @@ def __init__( val_transforms (Optional[torchvision.transforms.Compose]): Additional Transforms to be applied to the validation/test data """ - self.dataset_name = dataset_name if dataset_name is not None \ - else hash_array_or_matrix(train_tensors[0]) + self.dataset_name = dataset_name + + if self.dataset_name is None: + self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) if not hasattr(train_tensors[0], 'shape'): type_check(train_tensors, val_tensors) self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors - self.cross_validators: Dict[str, CROSS_VAL_FN] = {} - self.holdout_validators: Dict[str, HOLDOUT_FN] = {} + self.cross_validators: Dict[str, CrossValFunc] = {} + self.holdout_validators: Dict[str, HoldOutFunc] = {} self.rng = np.random.RandomState(seed=seed) self.shuffle = shuffle self.resampling_strategy = resampling_strategy @@ -134,8 +138,8 @@ def __init__( self.is_small_preprocess = True # Make sure cross validation splits are created once - self.cross_validators = get_cross_validators(*CrossValTypes) - self.holdout_validators = get_holdout_validators(*HoldoutValTypes) + self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes) + self.holdout_validators = 
HoldOutFuncs.get_holdout_validators(*HoldoutValTypes) self.splits = self.get_splits_from_resampling_strategy() # We also need to be able to transform the data, be it for pre-processing @@ -263,7 +267,7 @@ def create_cross_val_splits( if not isinstance(cross_val_type, CrossValTypes): raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.') kwargs = {} - if is_stratified(cross_val_type): + if cross_val_type.is_stratified(): # we need additional information about the data for stratification kwargs["stratify"] = self.train_tensors[-1] splits = self.cross_validators[cross_val_type.name]( @@ -298,7 +302,7 @@ def create_holdout_val_split( if not isinstance(holdout_val_type, HoldoutValTypes): raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.') kwargs = {} - if is_stratified(holdout_val_type): + if holdout_val_type.is_stratified(): # we need additional information about the data for stratification kwargs["stratify"] = self.train_tensors[-1] train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs) @@ -321,7 +325,8 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]: return (TransformSubset(self, self.splits[split_id][0], train=True), TransformSubset(self, self.splits[split_id][1], train=False)) - def replace_data(self, X_train: BaseDatasetType, X_test: Optional[BaseDatasetType]) -> 'BaseDataset': + def replace_data(self, X_train: BaseDatasetInputType, + X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset': """ To speed up the training of small dataset, early pre-processing of the data can be made on the fly by the pipeline. 
diff --git a/autoPyTorch/datasets/resampling_strategy.py b/autoPyTorch/datasets/resampling_strategy.py index b853fac0a..765a31cdb 100644 --- a/autoPyTorch/datasets/resampling_strategy.py +++ b/autoPyTorch/datasets/resampling_strategy.py @@ -16,7 +16,7 @@ # Use callback protocol as workaround, since callable with function fields count 'self' as argument -class CROSS_VAL_FN(Protocol): +class CrossValFunc(Protocol): def __call__(self, num_splits: int, indices: np.ndarray, @@ -24,25 +24,57 @@ def __call__(self, ... -class HOLDOUT_FN(Protocol): +class HoldOutFunc(Protocol): def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any] ) -> Tuple[np.ndarray, np.ndarray]: ... class CrossValTypes(IntEnum): + """The type of cross validation + + This class is used to specify the cross validation function + and is not supposed to be instantiated. + + Examples: This class is supposed to be used as follows + >>> cv_type = CrossValTypes.k_fold_cross_validation + >>> print(cv_type.name) + + k_fold_cross_validation + + >>> for cross_val_type in CrossValTypes: + print(cross_val_type.name, cross_val_type.value) + + stratified_k_fold_cross_validation 1 + k_fold_cross_validation 2 + stratified_shuffle_split_cross_validation 3 + shuffle_split_cross_validation 4 + time_series_cross_validation 5 + """ stratified_k_fold_cross_validation = 1 k_fold_cross_validation = 2 stratified_shuffle_split_cross_validation = 3 shuffle_split_cross_validation = 4 time_series_cross_validation = 5 + def is_stratified(self) -> bool: + stratified = [self.stratified_k_fold_cross_validation, + self.stratified_shuffle_split_cross_validation] + return getattr(self, self.name) in stratified + class HoldoutValTypes(IntEnum): + """TODO: change to enum using functools.partial""" + """The type of hold out validation (refer to CrossValTypes' doc-string)""" holdout_validation = 6 stratified_holdout_validation = 7 + def is_stratified(self) -> bool: + stratified = [self.stratified_holdout_validation] 
+ return getattr(self, self.name) in stratified + +# TODO: replace it with another way RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes] DEFAULT_RESAMPLING_PARAMETERS = { @@ -67,87 +99,111 @@ class HoldoutValTypes(IntEnum): } # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]] -def get_cross_validators(*cross_val_types: CrossValTypes) -> Dict[str, CROSS_VAL_FN]: - cross_validators = {} # type: Dict[str, CROSS_VAL_FN] - for cross_val_type in cross_val_types: - cross_val_fn = globals()[cross_val_type.name] - cross_validators[cross_val_type.name] = cross_val_fn - return cross_validators - - -def get_holdout_validators(*holdout_val_types: HoldoutValTypes) -> Dict[str, HOLDOUT_FN]: - holdout_validators = {} # type: Dict[str, HOLDOUT_FN] - for holdout_val_type in holdout_val_types: - holdout_val_fn = globals()[holdout_val_type.name] - holdout_validators[holdout_val_type.name] = holdout_val_fn - return holdout_validators - - -def is_stratified(val_type: Union[str, CrossValTypes, HoldoutValTypes]) -> bool: - if isinstance(val_type, str): - return val_type.lower().startswith("stratified") - else: - return val_type.name.lower().startswith("stratified") - - -def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]: - train, val = train_test_split(indices, test_size=val_share, shuffle=False) - return train, val - - -def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) \ - -> Tuple[np.ndarray, np.ndarray]: - train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"]) - return train, val - - -def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ - -> List[Tuple[np.ndarray, np.ndarray]]: - cv = ShuffleSplit(n_splits=num_splits) - splits = list(cv.split(indices)) - return splits - - -def stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ - -> 
List[Tuple[np.ndarray, np.ndarray]]: - cv = StratifiedShuffleSplit(n_splits=num_splits) - splits = list(cv.split(indices, kwargs["stratify"])) - return splits - - -def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ - -> List[Tuple[np.ndarray, np.ndarray]]: - cv = StratifiedKFold(n_splits=num_splits) - splits = list(cv.split(indices, kwargs["stratify"])) - return splits - - -def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) -> List[Tuple[np.ndarray, np.ndarray]]: - """ - Standard k fold cross validation. - - :param indices: array of indices to be split - :param num_splits: number of cross validation splits - :return: list of tuples of training and validation indices - """ - cv = KFold(n_splits=num_splits) - splits = list(cv.split(indices)) - return splits - - -def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ - -> List[Tuple[np.ndarray, np.ndarray]]: - """ - Returns train and validation indices respecting the temporal ordering of the data. 
- Dummy example: [0, 1, 2, 3] with 3 folds yields - [0] [1] - [0, 1] [2] - [0, 1, 2] [3] - - :param indices: array of indices to be split - :param num_splits: number of cross validation splits - :return: list of tuples of training and validation indices - """ - cv = TimeSeriesSplit(n_splits=num_splits) - splits = list(cv.split(indices)) - return splits +class HoldOutFuncs(): + @staticmethod + def holdout_validation(val_share: float, + indices: np.ndarray, + **kwargs: Any + ) -> Tuple[np.ndarray, np.ndarray]: + train, val = train_test_split(indices, test_size=val_share, shuffle=False) + return train, val + + @staticmethod + def stratified_holdout_validation(val_share: float, + indices: np.ndarray, + **kwargs: Any + ) -> Tuple[np.ndarray, np.ndarray]: + train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"]) + return train, val + + @classmethod + def get_holdout_validators(cls, *holdout_val_types: HoldoutValTypes) -> Dict[str, HoldOutFunc]: + + holdout_validators = { + holdout_val_type.name: getattr(cls, holdout_val_type.name) + for holdout_val_type in holdout_val_types + } + return holdout_validators + + +class CrossValFuncs(): + @staticmethod + def shuffle_split_cross_validation(num_splits: int, + indices: np.ndarray, + **kwargs: Any + ) -> List[Tuple[np.ndarray, np.ndarray]]: + cv = ShuffleSplit(n_splits=num_splits) + splits = list(cv.split(indices)) + return splits + + @staticmethod + def stratified_shuffle_split_cross_validation(num_splits: int, + indices: np.ndarray, + **kwargs: Any + ) -> List[Tuple[np.ndarray, np.ndarray]]: + cv = StratifiedShuffleSplit(n_splits=num_splits) + splits = list(cv.split(indices, kwargs["stratify"])) + return splits + + @staticmethod + def stratified_k_fold_cross_validation(num_splits: int, + indices: np.ndarray, + **kwargs: Any + ) -> List[Tuple[np.ndarray, np.ndarray]]: + cv = StratifiedKFold(n_splits=num_splits) + splits = list(cv.split(indices, kwargs["stratify"])) + return splits 
+ + @staticmethod + def k_fold_cross_validation(num_splits: int, + indices: np.ndarray, + **kwargs: Any + ) -> List[Tuple[np.ndarray, np.ndarray]]: + """ + Standard k fold cross validation. + + Args: + indices (np.ndarray): array of indices to be split + num_splits (int): number of cross validation splits + + Returns: + splits (List[Tuple[List, List]]): list of tuples of training and validation indices + """ + cv = KFold(n_splits=num_splits) + splits = list(cv.split(indices)) + return splits + + @staticmethod + def time_series_cross_validation(num_splits: int, + indices: np.ndarray, + **kwargs: Any + ) -> List[Tuple[np.ndarray, np.ndarray]]: + """ + Returns train and validation indices respecting the temporal ordering of the data. + + Args: + indices (np.ndarray): array of indices to be split + num_splits (int): number of cross validation splits + + Returns: + splits (List[Tuple[List, List]]): list of tuples of training and validation indices + + Examples: + >>> indices = np.array([0, 1, 2, 3]) + >>> CrossValFuncs.time_series_cross_validation(3, indices) + [([0], [1]), + ([0, 1], [2]), + ([0, 1, 2], [3])] + + """ + cv = TimeSeriesSplit(n_splits=num_splits) + splits = list(cv.split(indices)) + return splits + + @classmethod + def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, CrossValFunc]: + cross_validators = { + cross_val_type.name: getattr(cls, cross_val_type.name) + for cross_val_type in cross_val_types + } + return cross_validators diff --git a/autoPyTorch/datasets/time_series_dataset.py b/autoPyTorch/datasets/time_series_dataset.py index 7b0435d19..edd07a80e 100644 --- a/autoPyTorch/datasets/time_series_dataset.py +++ b/autoPyTorch/datasets/time_series_dataset.py @@ -6,10 +6,10 @@ from autoPyTorch.datasets.base_dataset import BaseDataset from autoPyTorch.datasets.resampling_strategy import ( + CrossValFuncs, CrossValTypes, - HoldoutValTypes, - get_cross_validators, - get_holdout_validators + HoldOutFuncs, + HoldoutValTypes ) 
TIME_SERIES_FORECASTING_INPUT = Tuple[np.ndarray, np.ndarray] # currently only numpy arrays are supported @@ -60,8 +60,8 @@ def __init__(self, train_transforms=train_transforms, val_transforms=val_transforms, ) - self.cross_validators = get_cross_validators(CrossValTypes.time_series_cross_validation) - self.holdout_validators = get_holdout_validators(HoldoutValTypes.holdout_validation) + self.cross_validators = CrossValFuncs.get_cross_validators(CrossValTypes.time_series_cross_validation) + self.holdout_validators = HoldOutFuncs.get_holdout_validators(HoldoutValTypes.holdout_validation) def _check_time_series_forecasting_inputs(target_variables: Tuple[int], @@ -117,13 +117,13 @@ def __init__(self, val=val, task_type="time_series_classification") super().__init__(train_tensors=train, val_tensors=val, shuffle=True) - self.cross_validators = get_cross_validators( + self.cross_validators = CrossValFuncs.get_cross_validators( CrossValTypes.stratified_k_fold_cross_validation, CrossValTypes.k_fold_cross_validation, CrossValTypes.shuffle_split_cross_validation, CrossValTypes.stratified_shuffle_split_cross_validation ) - self.holdout_validators = get_holdout_validators( + self.holdout_validators = HoldOutFuncs.get_holdout_validators( HoldoutValTypes.holdout_validation, HoldoutValTypes.stratified_holdout_validation ) @@ -135,11 +135,11 @@ def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np. 
val=val, task_type="time_series_regression") super().__init__(train_tensors=train, val_tensors=val, shuffle=True) - self.cross_validators = get_cross_validators( + self.cross_validators = CrossValFuncs.get_cross_validators( CrossValTypes.k_fold_cross_validation, CrossValTypes.shuffle_split_cross_validation ) - self.holdout_validators = get_holdout_validators( + self.holdout_validators = HoldOutFuncs.get_holdout_validators( HoldoutValTypes.holdout_validation ) diff --git a/autoPyTorch/ensemble/ensemble_builder.py b/autoPyTorch/ensemble/ensemble_builder.py index 434849ef1..e236f091b 100644 --- a/autoPyTorch/ensemble/ensemble_builder.py +++ b/autoPyTorch/ensemble/ensemble_builder.py @@ -66,57 +66,56 @@ def __init__( logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT, ): """ SMAC callback to handle ensemble building - Parameters - ---------- - start_time: int - the time when this job was started, to account for any latency in job allocation - time_left_for_ensemble: int - How much time is left for the task. Job should finish within this allocated time - backend: util.backend.Backend - backend to write and read files - dataset_name: str - name of dataset - task_type: int - what type of output is expected. If Binary, we need to argmax the one hot encoding. - metrics: List[autoPyTorchMetric], - A set of metrics that will be used to get performance estimates - opt_metric: str - name of the optimization metrics - ensemble_size: int - maximal size of ensemble (passed to ensemble_selection) - ensemble_nbest: int/float - if int: consider only the n best prediction - if float: consider only this fraction of the best models - Both wrt to validation predictions - If performance_range_threshold > 0, might return less models - max_models_on_disc: Union[float, int] - Defines the maximum number of models that are kept in the disc. - If int, it must be greater or equal than 1, and dictates the max number of - models to keep. 
- If float, it will be interpreted as the max megabytes allowed of disc space. That - is, if the number of ensemble candidates require more disc space than this float - value, the worst models will be deleted to keep within this budget. - Models and predictions of the worst-performing models will be deleted then. - If None, the feature is disabled. - It defines an upper bound on the models that can be used in the ensemble. - seed: int - random seed - max_iterations: int - maximal number of iterations to run this script - (default None --> deactivated) - precision: [16,32,64,128] - precision of floats to read the predictions - memory_limit: Optional[int] - memory limit in mb. If ``None``, no memory limit is enforced. - read_at_most: int - read at most n new prediction files in each iteration - logger_port: int - port in where to publish a msg - Returns - ------- - List[Tuple[int, float, float, float]]: - A list with the performance history of this ensemble, of the form - [[pandas_timestamp, train_performance, val_performance, test_performance], ...] + Args: + start_time: int + the time when this job was started, to account for any latency in job allocation + time_left_for_ensemble: int + How much time is left for the task. Job should finish within this allocated time + backend: util.backend.Backend + backend to write and read files + dataset_name: str + name of dataset + task_type: int + what type of output is expected. If Binary, we need to argmax the one hot encoding. 
+ metrics: List[autoPyTorchMetric], + A set of metrics that will be used to get performance estimates + opt_metric: str + name of the optimization metrics + ensemble_size: int + maximal size of ensemble (passed to ensemble_selection) + ensemble_nbest: int/float + if int: consider only the n best prediction + if float: consider only this fraction of the best models + Both wrt to validation predictions + If performance_range_threshold > 0, might return less models + max_models_on_disc: Union[float, int] + Defines the maximum number of models that are kept in the disc. + If int, it must be greater or equal than 1, and dictates the max number of + models to keep. + If float, it will be interpreted as the max megabytes allowed of disc space. That + is, if the number of ensemble candidates require more disc space than this float + value, the worst models will be deleted to keep within this budget. + Models and predictions of the worst-performing models will be deleted then. + If None, the feature is disabled. + It defines an upper bound on the models that can be used in the ensemble. + seed: int + random seed + max_iterations: int + maximal number of iterations to run this script + (default None --> deactivated) + precision: [16,32,64,128] + precision of floats to read the predictions + memory_limit: Optional[int] + memory limit in mb. If ``None``, no memory limit is enforced. + read_at_most: int + read at most n new prediction files in each iteration + logger_port: int + port in where to publish a msg + + Returns: + List[Tuple[int, float, float, float]]: + A list with the performance history of this ensemble, of the form + [[pandas_timestamp, train_performance, val_performance, test_performance], ...] """ self.start_time = start_time self.time_left_for_ensembles = time_left_for_ensembles