Skip to content

Commit

Permalink
Refactoring base dataset splitting functions (automl#106)
Browse files Browse the repository at this point in the history
* [Fork from automl#105] Made CrossValFuncs and HoldOutFuncs class to group the functions

* Modified time_series_dataset.py to be compatible with resampling_strategy.py

* [fix]: back to the renamed version of CROSS_VAL_FN from temporal SplitFunc typing.

* fixed flake8 issues in three files

* fixed the flake8 issues

* [refactor] Address the francisco's comments

* [refactor] Adress the francisco's comments

* [refactor] Address the doc-string issue in TransformSubset class

* [fix] Address flake8 issues

* [fix] Fix flake8 issue

* [fix] Fix mypy issues raised by github check

* [fix] Fix a mypy issue

* [fix] Fix a contradiction in holdout_stratified_validation

Since stratified splitting requires to shuffle by default
and it raises error in the github check,
I fixed this issue.

* [fix] Address the francisco's review

* [fix] Fix a mypy issue tabular_dataset.py

* [fix] Address the francisco's comment about the self.dataset_name

Since we would to use the dataset name which does not have any name,
I decided to get self.dataset_name back to Optional[str].

* [fix] Fix mypy issues
  • Loading branch information
nabenabe0928 authored Apr 28, 2021
1 parent a4e08e2 commit fae72a4
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 180 deletions.
22 changes: 11 additions & 11 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import time
import typing
import unittest.mock
import uuid
import warnings
from abc import abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union, cast
Expand Down Expand Up @@ -782,13 +781,15 @@ def _search(
":{}".format(self.task_type, dataset.task_type))

# Initialise information needed for the experiment
experiment_task_name = 'runSearch'
experiment_task_name: str = 'runSearch'
dataset_requirements = get_dataset_requirements(
info=self._get_required_dataset_properties(dataset))
self._dataset_requirements = dataset_requirements
dataset_properties = dataset.get_dataset_properties(dataset_requirements)
self._stopwatch.start_task(experiment_task_name)
self.dataset_name = dataset.dataset_name
assert self.dataset_name is not None

if self._logger is None:
self._logger = self._get_logger(self.dataset_name)
self._all_supported_metrics = all_supported_metrics
Expand Down Expand Up @@ -897,7 +898,7 @@ def _search(
start_time=time.time(),
time_left_for_ensembles=time_left_for_ensembles,
backend=copy.deepcopy(self._backend),
dataset_name=dataset.dataset_name,
dataset_name=str(dataset.dataset_name),
output_type=STRING_TO_OUTPUT_TYPES[dataset.output_type],
task_type=STRING_TO_TASK_TYPES[self.task_type],
metrics=[self._metric],
Expand All @@ -916,7 +917,7 @@ def _search(
self._stopwatch.stop_task(ensemble_task_name)

# ==> Run SMAC
smac_task_name = 'runSMAC'
smac_task_name: str = 'runSMAC'
self._stopwatch.start_task(smac_task_name)
elapsed_time = self._stopwatch.wall_elapsed(experiment_task_name)
time_left_for_smac = max(0, total_walltime_limit - elapsed_time)
Expand All @@ -928,7 +929,7 @@ def _search(

_proc_smac = AutoMLSMBO(
config_space=self.search_space,
dataset_name=dataset.dataset_name,
dataset_name=str(dataset.dataset_name),
backend=self._backend,
total_walltime_limit=total_walltime_limit,
func_eval_time_limit_secs=func_eval_time_limit_secs,
Expand Down Expand Up @@ -1035,11 +1036,11 @@ def refit(
Returns:
self
"""
if self.dataset_name is None:
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))

self.dataset_name = dataset.dataset_name

if self._logger is None:
self._logger = self._get_logger(self.dataset_name)
self._logger = self._get_logger(str(self.dataset_name))

dataset_requirements = get_dataset_requirements(
info=self._get_required_dataset_properties(dataset))
Expand Down Expand Up @@ -1105,11 +1106,10 @@ def fit(self,
Returns:
(BasePipeline): fitted pipeline
"""
if self.dataset_name is None:
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
self.dataset_name = dataset.dataset_name

if self._logger is None:
self._logger = self._get_logger(self.dataset_name)
self._logger = self._get_logger(str(self.dataset_name))

# get dataset properties
dataset_requirements = get_dataset_requirements(
Expand Down
51 changes: 28 additions & 23 deletions autoPyTorch/datasets/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import uuid
from abc import ABCMeta
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

Expand All @@ -13,18 +15,17 @@

from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
from autoPyTorch.datasets.resampling_strategy import (
CROSS_VAL_FN,
CrossValFunc,
CrossValFuncs,
CrossValTypes,
DEFAULT_RESAMPLING_PARAMETERS,
HOLDOUT_FN,
HoldoutValTypes,
get_cross_validators,
get_holdout_validators,
is_stratified,
HoldOutFunc,
HoldOutFuncs,
HoldoutValTypes
)
from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix
from autoPyTorch.utils.common import FitRequirement

BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]


def check_valid_data(data: Any) -> None:
Expand All @@ -33,7 +34,8 @@ def check_valid_data(data: Any) -> None:
'The specified Data for Dataset must have both __getitem__ and __len__ attribute.')


def type_check(train_tensors: BaseDatasetType, val_tensors: Optional[BaseDatasetType] = None) -> None:
def type_check(train_tensors: BaseDatasetInputType,
val_tensors: Optional[BaseDatasetInputType] = None) -> None:
"""To avoid unexpected behavior, we use loops over indices."""
for i in range(len(train_tensors)):
check_valid_data(train_tensors[i])
Expand All @@ -49,8 +51,8 @@ class TransformSubset(Subset):
we require different transformation for each data point.
This class helps to take the subset of the dataset
with either training or validation transformation.
We achieve so by adding a train flag to the pytorch subset
The TransformSubset allows to add train flags
while indexing the main dataset towards this goal.
Attributes:
dataset (BaseDataset/Dataset): Dataset to sample the subset
Expand All @@ -71,10 +73,10 @@ def __getitem__(self, idx: int) -> np.ndarray:
class BaseDataset(Dataset, metaclass=ABCMeta):
def __init__(
self,
train_tensors: BaseDatasetType,
train_tensors: BaseDatasetInputType,
dataset_name: Optional[str] = None,
val_tensors: Optional[BaseDatasetType] = None,
test_tensors: Optional[BaseDatasetType] = None,
val_tensors: Optional[BaseDatasetInputType] = None,
test_tensors: Optional[BaseDatasetInputType] = None,
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
shuffle: Optional[bool] = True,
Expand Down Expand Up @@ -106,14 +108,16 @@ def __init__(
val_transforms (Optional[torchvision.transforms.Compose]):
Additional Transforms to be applied to the validation/test data
"""
self.dataset_name = dataset_name if dataset_name is not None \
else hash_array_or_matrix(train_tensors[0])
self.dataset_name = dataset_name

if self.dataset_name is None:
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))

if not hasattr(train_tensors[0], 'shape'):
type_check(train_tensors, val_tensors)
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
self.cross_validators: Dict[str, CrossValFunc] = {}
self.holdout_validators: Dict[str, HoldOutFunc] = {}
self.rng = np.random.RandomState(seed=seed)
self.shuffle = shuffle
self.resampling_strategy = resampling_strategy
Expand All @@ -134,8 +138,8 @@ def __init__(
self.is_small_preprocess = True

# Make sure cross validation splits are created once
self.cross_validators = get_cross_validators(*CrossValTypes)
self.holdout_validators = get_holdout_validators(*HoldoutValTypes)
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
self.splits = self.get_splits_from_resampling_strategy()

# We also need to be able to transform the data, be it for pre-processing
Expand Down Expand Up @@ -263,7 +267,7 @@ def create_cross_val_splits(
if not isinstance(cross_val_type, CrossValTypes):
raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.')
kwargs = {}
if is_stratified(cross_val_type):
if cross_val_type.is_stratified():
# we need additional information about the data for stratification
kwargs["stratify"] = self.train_tensors[-1]
splits = self.cross_validators[cross_val_type.name](
Expand Down Expand Up @@ -298,7 +302,7 @@ def create_holdout_val_split(
if not isinstance(holdout_val_type, HoldoutValTypes):
raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.')
kwargs = {}
if is_stratified(holdout_val_type):
if holdout_val_type.is_stratified():
# we need additional information about the data for stratification
kwargs["stratify"] = self.train_tensors[-1]
train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs)
Expand All @@ -321,7 +325,8 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
return (TransformSubset(self, self.splits[split_id][0], train=True),
TransformSubset(self, self.splits[split_id][1], train=False))

def replace_data(self, X_train: BaseDatasetType, X_test: Optional[BaseDatasetType]) -> 'BaseDataset':
def replace_data(self, X_train: BaseDatasetInputType,
X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset':
"""
To speed up the training of small dataset, early pre-processing of the data
can be made on the fly by the pipeline.
Expand Down
Loading

0 comments on commit fae72a4

Please sign in to comment.