From a11caf417145d2e5dc9ee9c5b968a48f5653a2f2 Mon Sep 17 00:00:00 2001 From: Ravin Kohli <13005107+ravinkohli@users.noreply.github.com> Date: Thu, 4 Nov 2021 13:11:01 +0100 Subject: [PATCH] [ADD] Add column transformer (#305) * Match paper libraries-versions * Update README.md * Update README.md * Update README.md * [FIX] master branch README (#209) * Enable github actions (#273) * Update README.md * Create CITATION.cff * Added column transformer, changed requirements and added tests * remove redundant lines * Remove unwanted change made * Fix bug in test api and dummy forward pass * Fix silly bugs * increase time to pass test * remove parallel capabilities of traditional learners to resolve bug in docs building * almost fixed * Add documentation for tabularfeaturevalidator * fix flake * fix silly bug * address comment from shuhei * rename enc_columns to transformed_columns in the remaining places * fix bug in test * fix mypy * add shuhei's suggestion Co-authored-by: chico Co-authored-by: Marius Lindauer Co-authored-by: Frank Co-authored-by: Francisco Rivera Valverde <44504424+franchuterivera@users.noreply.github.com> --- autoPyTorch/api/base_task.py | 3 +- autoPyTorch/data/base_feature_validator.py | 10 +- autoPyTorch/data/tabular_feature_validator.py | 225 ++++++++++-------- .../network_backbone/InceptionTimeBackbone.py | 2 +- .../estimator_configs/extra_trees.json | 3 +- .../traditional_ml/estimator_configs/knn.json | 3 +- .../traditional_ml/estimator_configs/lgb.json | 3 +- .../estimator_configs/random_forest.json | 3 +- .../estimator_configs/rotation_forest.json | 1 - .../example_custom_configuration_space.py | 12 +- requirements.txt | 4 +- test/test_api/test_base_api.py | 2 +- test/test_data/test_feature_validator.py | 54 ++++- 13 files changed, 188 insertions(+), 137 deletions(-) diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index 94add94bd..2ab5650b1 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ 
-1322,7 +1322,8 @@ def get_incumbent_results( if not include_traditional: # traditional classifiers have trainer_configuration in their additional info run_history_data = dict( - filter(lambda elem: elem[1].additional_info is not None and elem[1]. + filter(lambda elem: elem[1].status == StatusType.SUCCESS and elem[1]. + additional_info is not None and elem[1]. additional_info['configuration_origin'] != 'traditional', run_history_data.items())) run_history_data = dict( diff --git a/autoPyTorch/data/base_feature_validator.py b/autoPyTorch/data/base_feature_validator.py index 2ef02ceba..fe4063611 100644 --- a/autoPyTorch/data/base_feature_validator.py +++ b/autoPyTorch/data/base_feature_validator.py @@ -35,11 +35,6 @@ class BaseFeatureValidator(BaseEstimator): List of the column types found by this estimator during fit. data_type (str): Class name of the data type provided during fit. - encoder (typing.Optional[BaseEstimator]) - Host a encoder object if the data requires transformation (for example, - if provided a categorical column in a pandas DataFrame) - enc_columns (typing.List[str]) - List of columns that were encoded. 
""" def __init__(self, logger: typing.Optional[typing.Union[PicklableClientLogger, logging.Logger @@ -51,8 +46,8 @@ def __init__(self, self.dtypes = [] # type: typing.List[str] self.column_order = [] # type: typing.List[str] - self.encoder = None # type: typing.Optional[BaseEstimator] - self.enc_columns = [] # type: typing.List[str] + self.column_transformer = None # type: typing.Optional[BaseEstimator] + self.transformed_columns = [] # type: typing.List[str] self.logger: typing.Union[ PicklableClientLogger, logging.Logger @@ -61,6 +56,7 @@ def __init__(self, # Required for dataset properties self.num_features = None # type: typing.Optional[int] self.categories = [] # type: typing.List[typing.List[int]] + self.categorical_columns: typing.List[int] = [] self.numerical_columns: typing.List[int] = [] diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index 1a8ce5fa8..61dc60d6a 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -1,5 +1,6 @@ import functools import typing +from typing import Dict, List import numpy as np @@ -13,11 +14,108 @@ from sklearn.base import BaseEstimator from sklearn.compose import ColumnTransformer from sklearn.exceptions import NotFittedError +from sklearn.impute import SimpleImputer +from sklearn.pipeline import make_pipeline from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SUPPORTED_FEAT_TYPES +def _create_column_transformer( + preprocessors: Dict[str, List[BaseEstimator]], + categorical_columns: List[str], +) -> ColumnTransformer: + """ + Given a dictionary of preprocessors, this function + creates a sklearn column transformer with appropriate + columns associated with their preprocessors. + + Args: + preprocessors (Dict[str, List[BaseEstimator]]): + Dictionary containing list of numerical and categorical preprocessors. 
+ categorical_columns (List[str]): + List of names of categorical columns + + Returns: + ColumnTransformer + """ + + categorical_pipeline = make_pipeline(*preprocessors['categorical']) + + return ColumnTransformer([ + ('categorical_pipeline', categorical_pipeline, categorical_columns)], + remainder='passthrough' + ) + + +def get_tabular_preprocessors() -> Dict[str, List[BaseEstimator]]: + """ + This function creates a Dictionary containing the list + of categorical preprocessors under the 'categorical' key + + Returns: + Dict[str, List[BaseEstimator]] + """ + preprocessors: Dict[str, List[BaseEstimator]] = dict() + + # Categorical Preprocessors + onehot_encoder = preprocessing.OrdinalEncoder(handle_unknown='use_encoded_value', + unknown_value=-1) + categorical_imputer = SimpleImputer(strategy='constant', copy=False) + + preprocessors['categorical'] = [categorical_imputer, onehot_encoder] + + return preprocessors + + class TabularFeatureValidator(BaseFeatureValidator): + """ + A subclass of `BaseFeatureValidator` made for tabular data. + It ensures that the dataset provided is of the expected format. + Subsequently, it preprocesses the data by fitting a column + transformer. + + Attributes: + categories (List[List[int]]): + List for which an element at each index is a + list containing the categories for the respective + categorical column. + transformed_columns (List[str]): + List of columns that were transformed. + column_transformer (Optional[BaseEstimator]): + Hosts an imputer and an encoder object if the data + requires transformation (for example, if provided a + categorical column in a pandas DataFrame) + column_order (List[str]): + List of the features stored in the order that + was fitted. 
+ numerical_columns (List[int]): + List of indices of numerical columns + categorical_columns (List[int]): + List of indices of categorical columns + """ + @staticmethod + def _comparator(cmp1: str, cmp2: str) -> int: + """Order so that categorical columns come left and numerical columns come right + + Args: + cmp1 (str): First variable to compare + cmp2 (str): Second variable to compare + + Raises: + ValueError: if the values of the variables to compare + are not in 'categorical' or 'numerical' + + Returns: + int: either [0, -1, 1] + """ + choices = ['categorical', 'numerical'] + if cmp1 not in choices or cmp2 not in choices: + raise ValueError('The comparator for the column order only accepts {}, ' + 'but got {} and {}'.format(choices, cmp1, cmp2)) + + idx1, idx2 = choices.index(cmp1), choices.index(cmp2) + return idx1 - idx2 + def _fit( self, X: SUPPORTED_FEAT_TYPES, @@ -60,51 +158,38 @@ def _fit( if not X.select_dtypes(include='object').empty: X = self.infer_objects(X) - self.enc_columns, self.feat_type = self._get_columns_to_encode(X) + self.transformed_columns, self.feat_type = self._get_columns_to_encode(X) - if len(self.enc_columns) > 0: - X = self.impute_nan_in_categories(X) + assert self.feat_type is not None - self.encoder = ColumnTransformer( - [ - ("encoder", - preprocessing.OrdinalEncoder( - handle_unknown='use_encoded_value', - unknown_value=-1, - ), self.enc_columns)], - remainder="passthrough" + if len(self.transformed_columns) > 0: + + preprocessors = get_tabular_preprocessors() + self.column_transformer = _create_column_transformer( + preprocessors=preprocessors, + categorical_columns=self.transformed_columns, ) # Mypy redefinition - assert self.encoder is not None - self.encoder.fit(X) - - # The column transformer reoders the feature types - we therefore need to change - # it as well - # This means columns are shifted to the right - def comparator(cmp1: str, cmp2: str) -> int: - if ( - cmp1 == 'categorical' and cmp2 == 'categorical' - or cmp1 == 
'numerical' and cmp2 == 'numerical' - ): - return 0 - elif cmp1 == 'categorical' and cmp2 == 'numerical': - return -1 - elif cmp1 == 'numerical' and cmp2 == 'categorical': - return 1 - else: - raise ValueError((cmp1, cmp2)) + assert self.column_transformer is not None + self.column_transformer.fit(X) + # The column transformer reorders the feature types + # therefore, we need to change the order of columns as well + # This means categorical columns are shifted to the left self.feat_type = sorted( self.feat_type, - key=functools.cmp_to_key(comparator) + key=functools.cmp_to_key(self._comparator) ) + encoded_categories = self.column_transformer.\ + named_transformers_['categorical_pipeline'].\ + named_steps['ordinalencoder'].categories_ self.categories = [ # We fit an ordinal encoder, where all categorical # columns are shifted to the left list(range(len(cat))) - for cat in self.encoder.transformers_[0][1].categories_ + for cat in encoded_categories ] for i, type_ in enumerate(self.feat_type): @@ -158,7 +243,7 @@ def transform( self._check_data(X) # Pandas related transformations - if hasattr(X, "iloc") and self.encoder is not None: + if hasattr(X, "iloc") and self.column_transformer is not None: if np.any(pd.isnull(X)): # After above check it means that if there is a NaN # the whole column must be NaN @@ -167,11 +252,7 @@ def transform( if X[column].isna().all(): X[column] = pd.to_numeric(X[column]) - # We also need to fillna on the transformation - # in case test data is provided - X = self.impute_nan_in_categories(X) - - X = self.encoder.transform(X) + X = self.column_transformer.transform(X) # Sparse related transformations # Not all sparse format support index sorting @@ -245,7 +326,7 @@ def _check_data( # Define the column to be encoded here as the feature validator is fitted once # per estimator - enc_columns, _ = self._get_columns_to_encode(X) + self.transformed_columns, self.feat_type = self._get_columns_to_encode(X) column_order = [column for column in 
X.columns] if len(self.column_order) > 0: @@ -282,13 +363,17 @@ def _get_columns_to_encode( A set of features that are going to be validated (type and dimensionality checks) and a encoder fitted in the case the data needs encoding Returns: - enc_columns (List[str]): + transformed_columns (List[str]): Columns to encode, if any feat_type: Type of each column numerical/categorical """ + + if len(self.transformed_columns) > 0 and self.feat_type is not None: + return self.transformed_columns, self.feat_type + # Register if a column needs encoding - enc_columns = [] + transformed_columns = [] # Also, register the feature types for the estimator feat_type = [] @@ -297,7 +382,7 @@ def _get_columns_to_encode( for i, column in enumerate(X.columns): if X[column].dtype.name in ['category', 'bool']: - enc_columns.append(column) + transformed_columns.append(column) feat_type.append('categorical') # Move away from np.issubdtype as it causes # TypeError: data type not understood in certain pandas types @@ -339,7 +424,7 @@ def _get_columns_to_encode( ) else: feat_type.append('numerical') - return enc_columns, feat_type + return transformed_columns, feat_type def list_to_dataframe( self, @@ -429,59 +514,3 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame: self.object_dtype_mapping = {column: X[column].dtype for column in X.columns} self.logger.debug(f"Infer Objects: {self.object_dtype_mapping}") return X - - def impute_nan_in_categories(self, X: pd.DataFrame) -> pd.DataFrame: - """ - impute missing values before encoding, - remove once sklearn natively supports - it in ordinal encoding. Sklearn issue: - "https://github.com/scikit-learn/scikit-learn/issues/17123)" - - Arguments: - X (pd.DataFrame): - data to be interpreted. 
- - Returns: - pd.DataFrame - """ - - # To be on the safe side, map always to the same missing - # value per column - if not hasattr(self, 'dict_nancol_to_missing'): - self.dict_missing_value_per_col: typing.Dict[str, typing.Any] = {} - - # First make sure that we do not alter the type of the column which cause: - # TypeError: '<' not supported between instances of 'int' and 'str' - # in the encoding - for column in self.enc_columns: - if X[column].isna().any(): - if column not in self.dict_missing_value_per_col: - try: - float(X[column].dropna().values[0]) - can_cast_as_number = True - except Exception: - can_cast_as_number = False - if can_cast_as_number: - # In this case, we expect to have a number as category - # it might be string, but its value represent a number - missing_value: typing.Union[str, int] = '-1' if isinstance(X[column].dropna().values[0], - str) else -1 - else: - missing_value = 'Missing!' - - # Make sure this missing value is not seen before - # Do this check for categorical columns - # else modify the value - if hasattr(X[column], 'cat'): - while missing_value in X[column].cat.categories: - if isinstance(missing_value, str): - missing_value += '0' - else: - missing_value += missing_value - self.dict_missing_value_per_col[column] = missing_value - - # Convert the frame in place - X[column].cat.add_categories([self.dict_missing_value_per_col[column]], - inplace=True) - X.fillna({column: self.dict_missing_value_per_col[column]}, inplace=True) - return X diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/InceptionTimeBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/InceptionTimeBackbone.py index a574b5237..869f808ed 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/InceptionTimeBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/InceptionTimeBackbone.py @@ -79,7 +79,7 @@ def __init__(self, n_res_inputs: int, n_outputs: int): def forward(self, x: torch.Tensor, res: 
torch.Tensor) -> torch.Tensor: shortcut = self.shortcut(res) shortcut = self.bn(shortcut) - x += shortcut + x = x + shortcut return torch.relu(x) diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/extra_trees.json b/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/extra_trees.json index 29faff898..81f1d6383 100644 --- a/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/extra_trees.json +++ b/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/extra_trees.json @@ -1,4 +1,3 @@ { - "n_estimators" : 300, - "n_jobs" : -1 + "n_estimators" : 300 } diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/knn.json b/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/knn.json index 65f7df535..0fa7f95d4 100644 --- a/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/knn.json +++ b/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/knn.json @@ -1,4 +1,3 @@ { - "weights" : "uniform", - "n_jobs" : -1 + "weights" : "uniform" } diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/lgb.json b/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/lgb.json index 048fb5962..d8e061f5e 100644 --- a/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/lgb.json +++ b/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/lgb.json @@ -5,6 +5,5 @@ "min_data_in_leaf" : 3, "feature_fraction" : 0.9, "boosting_type" : "gbdt", - "learning_rate" : 0.03, - "num_threads" : -1 + "learning_rate" : 0.03 } diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/random_forest.json b/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/random_forest.json index 29faff898..81f1d6383 100644 --- a/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/random_forest.json +++ 
b/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/random_forest.json @@ -1,4 +1,3 @@ { - "n_estimators" : 300, - "n_jobs" : -1 + "n_estimators" : 300 } diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/rotation_forest.json b/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/rotation_forest.json index 5f582d9ae..2c63c0851 100644 --- a/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/rotation_forest.json +++ b/autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/rotation_forest.json @@ -1,3 +1,2 @@ { - "n_jobs" : -1 } diff --git a/examples/40_advanced/example_custom_configuration_space.py b/examples/40_advanced/example_custom_configuration_space.py index bd02e51f1..f552045c1 100644 --- a/examples/40_advanced/example_custom_configuration_space.py +++ b/examples/40_advanced/example_custom_configuration_space.py @@ -72,11 +72,7 @@ def get_search_space_updates(): ############################################################################ # Build and fit a classifier with include components # ================================================== - # AutoPyTorch can search for multiple configurations at the same time - # if multiple cores are allocated, using the n_jobs argument. By default, - # Only 1 core is used while searching for configurations. 
api = TabularClassificationTask( - n_jobs=2, search_space_updates=get_search_space_updates(), include_components={'network_backbone': ['MLPBackbone', 'ResNetBackbone'], 'encoder': ['OneHotEncoder']} @@ -91,8 +87,8 @@ def get_search_space_updates(): X_test=X_test.copy(), y_test=y_test.copy(), optimize_metric='accuracy', - total_walltime_limit=300, - func_eval_time_limit_secs=50 + total_walltime_limit=150, + func_eval_time_limit_secs=30 ) ############################################################################ @@ -122,8 +118,8 @@ def get_search_space_updates(): X_test=X_test.copy(), y_test=y_test.copy(), optimize_metric='accuracy', - total_walltime_limit=300, - func_eval_time_limit_secs=50 + total_walltime_limit=150, + func_eval_time_limit_secs=30 ) ############################################################################ diff --git a/requirements.txt b/requirements.txt index a2f23958f..6f81bfcb7 100755 --- a/requirements.txt +++ b/requirements.txt @@ -4,13 +4,13 @@ torchvision tensorboard scikit-learn>=0.24.0,<0.25.0 numpy -scipy>=0.14.1,<1.7.0 +scipy>=1.7 lockfile imgaug>=0.4.0 ConfigSpace>=0.4.14,<0.5 pynisher>=0.6.3 pyrfr>=0.7,<0.9 -smac>=0.13.1,<0.14 +smac==0.14.0 dask distributed>=2.2.0 catboost diff --git a/test/test_api/test_base_api.py b/test/test_api/test_base_api.py index 6c74515b1..126b702e6 100644 --- a/test/test_api/test_base_api.py +++ b/test/test_api/test_base_api.py @@ -134,7 +134,7 @@ def test_pipeline_get_budget(fit_dictionary_tabular, min_budget, max_budget, bud estimator._search(optimize_metric='accuracy', dataset=dataset, tae_func=pipeline_fit, min_budget=min_budget, max_budget=max_budget, budget_type=budget_type, enable_traditional_pipeline=False, - total_walltime_limit=10, func_eval_time_limit_secs=5, + total_walltime_limit=20, func_eval_time_limit_secs=10, load_models=False) assert list(smac_mock.call_args)[1]['ta_kwargs']['pipeline_config'] == default_pipeline_config assert list(smac_mock.call_args)[1]['max_budget'] == max_budget diff 
--git a/test/test_data/test_feature_validator.py b/test/test_data/test_feature_validator.py index b5430e5fc..7f2ff2507 100644 --- a/test/test_data/test_feature_validator.py +++ b/test/test_data/test_feature_validator.py @@ -1,4 +1,5 @@ import copy +import functools import numpy as np @@ -236,8 +237,10 @@ def test_featurevalidator_categorical_nan(input_data_featuretest): validator.fit(input_data_featuretest) transformed_X = validator.transform(input_data_featuretest) assert any(pd.isna(input_data_featuretest)) - assert any((-1 in categories) or ('-1' in categories) or ('Missing!' in categories) for categories in - validator.encoder.named_transformers_['encoder'].categories_) + categories_ = validator.column_transformer.named_transformers_['categorical_pipeline'].\ + named_steps['ordinalencoder'].categories_ + assert any(('0' in categories) or (0 in categories) or ('missing_value' in categories) for categories in + categories_) assert np.shape(input_data_featuretest) == np.shape(transformed_X) assert np.issubdtype(transformed_X.dtype, np.number) assert validator._is_fitted @@ -311,9 +314,9 @@ def test_featurevalidator_get_columns_to_encode(): for col in df.columns: df[col] = df[col].astype(col) - enc_columns, feature_types = validator._get_columns_to_encode(df) + transformed_columns, feature_types = validator._get_columns_to_encode(df) - assert enc_columns == ['category', 'bool'] + assert transformed_columns == ['category', 'bool'] assert feature_types == ['numerical', 'numerical', 'categorical', 'categorical'] @@ -371,14 +374,14 @@ def test_features_unsupported_calls_are_raised(): ), indirect=True ) -def test_no_encoder_created(input_data_featuretest): +def test_no_column_transformer_created(input_data_featuretest): """ Makes sure that for numerical only features, no encoder is created """ validator = TabularFeatureValidator() validator.fit(input_data_featuretest) validator.transform(input_data_featuretest) - assert validator.encoder is None + assert 
validator.column_transformer is None @pytest.mark.parametrize( @@ -389,18 +392,18 @@ def test_no_encoder_created(input_data_featuretest): ), indirect=True ) -def test_encoder_created(input_data_featuretest): +def test_column_transformer_created(input_data_featuretest): """ This test ensures an encoder is created if categorical data is provided """ validator = TabularFeatureValidator() validator.fit(input_data_featuretest) transformed_X = validator.transform(input_data_featuretest) - assert validator.encoder is not None + assert validator.column_transformer is not None # Make sure that the encoded features are actually encoded. Categorical columns are at # the start after transformation. In our fixtures, this is also honored prior encode - enc_columns, feature_types = validator._get_columns_to_encode(input_data_featuretest) + transformed_columns, feature_types = validator._get_columns_to_encode(input_data_featuretest) # At least one categorical assert 'categorical' in validator.feat_type @@ -515,7 +518,7 @@ def test_featurevalidator_new_data_after_fit(openml_id, if train_data_type == 'pandas': old_dtypes = copy.deepcopy(validator.dtypes) validator.dtypes = ['dummy' for dtype in X_train.dtypes] - with pytest.raises(ValueError, match=r"hanging the dtype of the features after fit"): + with pytest.raises(ValueError, match=r"Changing the dtype of the features after fit"): transformed_X = validator.transform(X_test) validator.dtypes = old_dtypes if test_data_type == 'pandas': @@ -523,3 +526,34 @@ def test_featurevalidator_new_data_after_fit(openml_id, X_test = X_test[reversed(columns)] with pytest.raises(ValueError, match=r"Changing the column order of the features"): transformed_X = validator.transform(X_test) + + +def test_comparator(): + numerical = 'numerical' + categorical = 'categorical' + + validator = TabularFeatureValidator + + with pytest.raises(ValueError, match=r"The comparator for the column order only accepts .*"): + dummy = 'dummy' + feat_type = [numerical, 
categorical, dummy] + feat_type = sorted( + feat_type, + key=functools.cmp_to_key(validator._comparator) + ) + + feat_type = [numerical, categorical] * 10 + ans = [categorical] * 10 + [numerical] * 10 + feat_type = sorted( + feat_type, + key=functools.cmp_to_key(validator._comparator) + ) + assert ans == feat_type + + feat_type = [numerical] * 10 + [categorical] * 10 + ans = [categorical] * 10 + [numerical] * 10 + feat_type = sorted( + feat_type, + key=functools.cmp_to_key(validator._comparator) + ) + assert ans == feat_type