From 76bdde76c0f0d77f3c27c873532dfa7ca0e5695d Mon Sep 17 00:00:00 2001
From: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
Date: Wed, 23 Jun 2021 17:37:17 +0200
Subject: [PATCH] [feat] Add flexible step-wise LR scheduler with minimum changes (#256)

* [doc] Add the bug fix information for ipython users
* [fix] Delete the gcc and gnn install steps, since I confirmed that we can build without them
* [feat] Add flexible step-wise LR scheduler with minimum changes. Since we would like to merge this feature promptly, I cut this new branch from the hot-fix-adapt... branch and narrowed down the scope of this PR. A subsequent PR will address the remaining issues, especially the formatting and mypy typing.
* [fix] Fix a flake8 issue and remove unneeded changes introduced by mis-branching
* [test] Add tests for the new features
* [fix] Fix flake8 issues and change torch.tensor to torch.Tensor in the typing checks. The reason for the change is that `import torch.tensor` raised a ModuleNotFoundError, and torch.tensor is not a tensor type class, whereas torch.Tensor is. Therefore, I changed torch.tensor to torch.Tensor.
* [feat] Make it possible to add the step_unit to ConfigSpace
* [fix] Fix pytest issues by adding the batch-wise update to the incumbent. Since the previous version always used the batch-wise update, I added step_unit = batch to avoid the errors raised by pytest.
* [fix] Add the step_unit info to the greedy portfolio json file. Since the latest version only supports the batch-wise update, I inserted step_unit == "batch" so that greedy portfolio selection can run.
* [refactor] Rebase onto the latest development branch and add the overridden functions in base_scheduler
* [fix] Fix flake8 and mypy issues
* [fix] Fix flake8 issues
* [test] Add the test for the train step and the lr scheduler check
* [refactor] Change the name to
* [fix] Fix flake8 issues
* [fix] Disable the step_interval option in the hyperparameter settings
* [fix] Change the default step_interval to epoch-wise
* [fix] Fix the step timing for ReduceLROnPlateau and fix flake8 issues
* [fix] Add the after-validation option to StepIntervalUnit for ReduceLROnPlateau
* [fix] Fix flake8 issues
* [fix] Fix the loss value for the epoch-wise scheduler update
* [fix] Delete the type check of step_interval and make it a property. Since step_interval should not be modified from outside, I made it a property of the base_scheduler class. Furthermore, since the type of step_interval only needs to be checked at initialization, I deleted the type check from the prepare method.
* [fix] Fix a mypy issue * [fix] Fix a mypy issue * [fix] Fix mypy issues * [fix] Fix mypy issues * [feedback] Address the Ravin's suggestions --- .../setup/lr_scheduler/CosineAnnealingLR.py | 8 +- .../CosineAnnealingWarmRestarts.py | 16 ++- .../components/setup/lr_scheduler/CyclicLR.py | 6 +- .../setup/lr_scheduler/ExponentialLR.py | 8 +- .../setup/lr_scheduler/NoScheduler.py | 7 +- .../setup/lr_scheduler/ReduceLROnPlateau.py | 14 +- .../components/setup/lr_scheduler/StepLR.py | 8 +- .../components/setup/lr_scheduler/__init__.py | 1 + .../setup/lr_scheduler/base_scheduler.py | 26 +++- .../setup/lr_scheduler/constants.py | 17 +++ .../training/trainer/MixUpTrainer.py | 19 +-- .../training/trainer/StandardTrainer.py | 17 ++- .../components/training/trainer/__init__.py | 3 +- .../training/trainer/base_trainer.py | 74 +++++++---- requirements.txt | 2 +- .../components/setup/test_setup.py | 52 +++++--- .../components/training/test_training.py | 124 ++++++++++++++---- 17 files changed, 287 insertions(+), 115 deletions(-) create mode 100644 autoPyTorch/pipeline/components/setup/lr_scheduler/constants.py diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingLR.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingLR.py index 0506cd046..12040178a 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingLR.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingLR.py @@ -8,10 +8,10 @@ import numpy as np import torch.optim.lr_scheduler -from torch.optim.lr_scheduler import _LRScheduler from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent +from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter @@ -26,13 +26,13 @@ class CosineAnnealingLR(BaseLRComponent): def __init__( self, T_max: int, + step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch, random_state: Optional[np.random.RandomState] = None ): - super().__init__() + super().__init__(step_interval) self.T_max = T_max self.random_state = random_state - self.scheduler = None # type: Optional[_LRScheduler] def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent: """ @@ -71,6 +71,8 @@ def get_hyperparameter_search_space( default_value=200, ) ) -> ConfigurationSpace: + cs = ConfigurationSpace() add_hyperparameter(cs, T_max, UniformIntegerHyperparameter) + return cs diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingWarmRestarts.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingWarmRestarts.py index aff52c426..61ecd0bc1 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingWarmRestarts.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingWarmRestarts.py @@ -1,15 +1,18 @@ from typing import Any, Dict, Optional, Union from ConfigSpace.configuration_space import ConfigurationSpace -from ConfigSpace.hyperparameters import UniformFloatHyperparameter, UniformIntegerHyperparameter +from ConfigSpace.hyperparameters import ( + UniformFloatHyperparameter, + UniformIntegerHyperparameter +) import numpy as np import torch.optim.lr_scheduler -from torch.optim.lr_scheduler import _LRScheduler from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import 
BaseLRComponent +from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter @@ -30,13 +33,13 @@ def __init__( self, T_0: int, T_mult: int, - random_state: Optional[np.random.RandomState] = None + step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch, + random_state: Optional[np.random.RandomState] = None, ): - super().__init__() + super().__init__(step_interval) self.T_0 = T_0 self.T_mult = T_mult self.random_state = random_state - self.scheduler = None # type: Optional[_LRScheduler] def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent: """ @@ -78,8 +81,9 @@ def get_hyperparameter_search_space( T_mult: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter='T_mult', value_range=(1.0, 2.0), default_value=1.0, - ), + ) ) -> ConfigurationSpace: + cs = ConfigurationSpace() add_hyperparameter(cs, T_0, UniformIntegerHyperparameter) add_hyperparameter(cs, T_mult, UniformFloatHyperparameter) diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/CyclicLR.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/CyclicLR.py index d3effbfa9..d26d3d495 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/CyclicLR.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/CyclicLR.py @@ -10,10 +10,10 @@ import numpy as np import torch.optim.lr_scheduler -from torch.optim.lr_scheduler import _LRScheduler from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent +from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter @@ -39,16 +39,16 @@ def __init__( base_lr: float, mode: str, step_size_up: int, + step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch, max_lr: float = 0.1, random_state: Optional[np.random.RandomState] = None ): - super().__init__() + super().__init__(step_interval) self.base_lr = base_lr self.mode = mode self.max_lr = max_lr self.step_size_up = step_size_up self.random_state = random_state - self.scheduler = None # type: Optional[_LRScheduler] def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent: """ diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/ExponentialLR.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/ExponentialLR.py index 036e8b302..dc57cfc1e 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/ExponentialLR.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/ExponentialLR.py @@ -8,10 +8,10 @@ import numpy as np import torch.optim.lr_scheduler -from torch.optim.lr_scheduler import _LRScheduler from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent +from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter @@ -27,13 +27,13 @@ class ExponentialLR(BaseLRComponent): def __init__( self, gamma: float, + step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch, random_state: Optional[np.random.RandomState] = None ): - super().__init__() + super().__init__(step_interval) self.gamma = gamma self.random_state = random_state - self.scheduler = None # type: Optional[_LRScheduler] def 
fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent: """ @@ -72,6 +72,8 @@ def get_hyperparameter_search_space( default_value=0.9, ) ) -> ConfigurationSpace: + cs = ConfigurationSpace() add_hyperparameter(cs, gamma, UniformFloatHyperparameter) + return cs diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/NoScheduler.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/NoScheduler.py index bbe0a0c85..5a1f2e571 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/NoScheduler.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/NoScheduler.py @@ -4,10 +4,9 @@ import numpy as np -from torch.optim.lr_scheduler import _LRScheduler - from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent +from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit class NoScheduler(BaseLRComponent): @@ -17,12 +16,12 @@ class NoScheduler(BaseLRComponent): """ def __init__( self, + step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch, random_state: Optional[np.random.RandomState] = None ): - super().__init__() + super().__init__(step_interval) self.random_state = random_state - self.scheduler = None # type: Optional[_LRScheduler] def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent: """ diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/ReduceLROnPlateau.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/ReduceLROnPlateau.py index 2d00eb2ea..ae87bfdd2 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/ReduceLROnPlateau.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/ReduceLROnPlateau.py @@ -10,10 +10,10 @@ import numpy as np import torch.optim.lr_scheduler -from torch.optim.lr_scheduler import _LRScheduler from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent +from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter @@ -31,7 +31,11 @@ class ReduceLROnPlateau(BaseLRComponent): factor (float): Factor by which the learning rate will be reduced. new_lr = lr * factor. patience (int): Number of epochs with no improvement after which learning rate will be reduced. 
+ step_interval (str): step should be called after validation in the case of ReduceLROnPlateau random_state (Optional[np.random.RandomState]): random state + + Reference: + https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html#torch.optim.lr_scheduler.ReduceLROnPlateau """ def __init__( @@ -39,14 +43,14 @@ def __init__( mode: str, factor: float, patience: int, - random_state: Optional[np.random.RandomState] = None + step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.valid, + random_state: Optional[np.random.RandomState] = None, ): - super().__init__() + super().__init__(step_interval) self.mode = mode self.factor = factor self.patience = patience self.random_state = random_state - self.scheduler = None # type: Optional[_LRScheduler] def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent: """ @@ -93,7 +97,7 @@ def get_hyperparameter_search_space( factor: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter='factor', value_range=(0.01, 0.9), default_value=0.1, - ), + ) ) -> ConfigurationSpace: cs = ConfigurationSpace() diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/StepLR.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/StepLR.py index fe01c2a3b..1917e61ae 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/StepLR.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/StepLR.py @@ -9,10 +9,10 @@ import numpy as np import torch.optim.lr_scheduler -from torch.optim.lr_scheduler import _LRScheduler from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent +from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter @@ -32,13 +32,13 @@ def __init__( self, step_size: int, gamma: float, + step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch, random_state: Optional[np.random.RandomState] = None ): - super().__init__() + super().__init__(step_interval) self.gamma = gamma self.step_size = step_size self.random_state = random_state - self.scheduler = None # type: Optional[_LRScheduler] def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent: """ @@ -80,7 +80,7 @@ def get_hyperparameter_search_space( step_size: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter='step_size', value_range=(1, 10), default_value=5, - ), + ) ) -> ConfigurationSpace: cs = ConfigurationSpace() diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/__init__.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/__init__.py index e42b41874..ceb91cd0f 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/__init__.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/__init__.py @@ -16,6 +16,7 @@ ) from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent + directory = os.path.split(__file__)[0] _schedulers = find_components(__package__, directory, diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/base_scheduler.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/base_scheduler.py index c541507b8..24368a8f0 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/base_scheduler.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/base_scheduler.py @@ -1,9 +1,10 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from 
torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from autoPyTorch.pipeline.components.setup.base_setup import autoPyTorchSetupComponent +from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit, StepIntervalUnitChoices from autoPyTorch.utils.common import FitRequirement @@ -11,13 +12,28 @@ class BaseLRComponent(autoPyTorchSetupComponent): """Provide an abstract interface for schedulers in Auto-Pytorch""" - def __init__(self) -> None: + def __init__(self, step_interval: Union[str, StepIntervalUnit]): super().__init__() self.scheduler = None # type: Optional[_LRScheduler] + self._step_interval: StepIntervalUnit + + if isinstance(step_interval, str): + if step_interval not in StepIntervalUnitChoices: + raise ValueError('step_interval must be either {}, but got {}.'.format( + StepIntervalUnitChoices, + step_interval + )) + self._step_interval = getattr(StepIntervalUnit, step_interval) + else: + self._step_interval = step_interval self.add_fit_requirements([ FitRequirement('optimizer', (Optimizer,), user_defined=False, dataset_property=False)]) + @property + def step_interval(self) -> StepIntervalUnit: + return self._step_interval + def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: """ Adds the scheduler into the fit dictionary 'X' and returns it. @@ -26,7 +42,11 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: Returns: (Dict[str, Any]): the updated 'X' dictionary """ - X.update({'lr_scheduler': self.scheduler}) + + X.update( + lr_scheduler=self.scheduler, + step_interval=self.step_interval + ) return X def get_scheduler(self) -> _LRScheduler: diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/constants.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/constants.py new file mode 100644 index 000000000..2e5895632 --- /dev/null +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/constants.py @@ -0,0 +1,17 @@ +from enum import Enum + + +class StepIntervalUnit(Enum): + """ + By which interval we perform the step for learning rate schedulers. + Attributes: + batch (str): We update every batch evaluation + epoch (str): We update every epoch + valid (str): We update every validation + """ + batch = 'batch' + epoch = 'epoch' + valid = 'valid' + + +StepIntervalUnitChoices = [step_interval.name for step_interval in StepIntervalUnit] diff --git a/autoPyTorch/pipeline/components/training/trainer/MixUpTrainer.py b/autoPyTorch/pipeline/components/training/trainer/MixUpTrainer.py index 3978fdab0..91e177b08 100644 --- a/autoPyTorch/pipeline/components/training/trainer/MixUpTrainer.py +++ b/autoPyTorch/pipeline/components/training/trainer/MixUpTrainer.py @@ -30,8 +30,8 @@ def __init__(self, alpha: float, weighted_loss: bool = False, self.weighted_loss = weighted_loss self.alpha = alpha - def data_preparation(self, X: np.ndarray, y: np.ndarray, - ) -> typing.Tuple[np.ndarray, typing.Dict[str, np.ndarray]]: + def data_preparation(self, X: torch.Tensor, y: torch.Tensor, + ) -> typing.Tuple[torch.Tensor, typing.Dict[str, np.ndarray]]: """ Depending on the trainer choice, data fed to the network might be pre-processed on a different way. That is, in standard training we provide the data to the @@ -39,22 +39,25 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray, alter the data. 
Args: - X (np.ndarray): The batch training features - y (np.ndarray): The batch training labels + X (torch.Tensor): The batch training features + y (torch.Tensor): The batch training labels Returns: - np.ndarray: that processes data + torch.Tensor: that processes data typing.Dict[str, np.ndarray]: arguments to the criterion function + TODO: Fix this typing. It is not np.ndarray. """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + lam = self.random_state.beta(self.alpha, self.alpha) if self.alpha > 0. else 1. - batch_size = X.size()[0] - index = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size) + batch_size = X.shape[0] + index = torch.randperm(batch_size).to(device) mixed_x = lam * X + (1 - lam) * X[index, :] y_a, y_b = y, y[index] return mixed_x, {'y_a': y_a, 'y_b': y_b, 'lam': lam} - def criterion_preparation(self, y_a: np.ndarray, y_b: np.ndarray = None, lam: float = 1.0 + def criterion_preparation(self, y_a: torch.Tensor, y_b: torch.Tensor = None, lam: float = 1.0 ) -> typing.Callable: return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) diff --git a/autoPyTorch/pipeline/components/training/trainer/StandardTrainer.py b/autoPyTorch/pipeline/components/training/trainer/StandardTrainer.py index 2df27b3b7..c89e17bc8 100644 --- a/autoPyTorch/pipeline/components/training/trainer/StandardTrainer.py +++ b/autoPyTorch/pipeline/components/training/trainer/StandardTrainer.py @@ -5,6 +5,8 @@ import numpy as np +import torch + from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType from autoPyTorch.pipeline.components.training.trainer.base_trainer import BaseTrainerComponent @@ -24,8 +26,8 @@ def __init__(self, weighted_loss: bool = False, super().__init__(random_state=random_state) self.weighted_loss = weighted_loss - def data_preparation(self, X: np.ndarray, y: np.ndarray, - ) -> typing.Tuple[np.ndarray, typing.Dict[str, np.ndarray]]: + def data_preparation(self, X: torch.Tensor, y: torch.Tensor, + ) -> typing.Tuple[torch.Tensor, typing.Dict[str, np.ndarray]]: """ Depending on the trainer choice, data fed to the network might be pre-processed on a different way. That is, in standard training we provide the data to the @@ -33,16 +35,17 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray, alter the data. Args: - X (np.ndarray): The batch training features - y (np.ndarray): The batch training labels + X (torch.Tensor): The batch training features + y (torch.Tensor): The batch training labels Returns: - np.ndarray: that processes data + torch.Tensor: that processes data typing.Dict[str, np.ndarray]: arguments to the criterion function + TODO: Fix this typing. It is not np.ndarray. 
""" return X, {'y_a': y} - def criterion_preparation(self, y_a: np.ndarray, y_b: np.ndarray = None, lam: float = 1.0 + def criterion_preparation(self, y_a: torch.Tensor, y_b: torch.Tensor = None, lam: float = 1.0 ) -> typing.Callable: return lambda criterion, pred: criterion(pred, y_a) @@ -51,7 +54,7 @@ def get_properties(dataset_properties: typing.Optional[typing.Dict[str, BaseData ) -> typing.Dict[str, typing.Union[str, bool]]: return { 'shortname': 'StandardTrainer', - 'name': 'StandardTrainer', + 'name': 'Standard Trainer', } @staticmethod diff --git a/autoPyTorch/pipeline/components/training/trainer/__init__.py b/autoPyTorch/pipeline/components/training/trainer/__init__.py index e9c2e7b46..b4221605a 100755 --- a/autoPyTorch/pipeline/components/training/trainer/__init__.py +++ b/autoPyTorch/pipeline/components/training/trainer/__init__.py @@ -266,7 +266,8 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic metrics_during_training=X['metrics_during_training'], scheduler=X['lr_scheduler'], task_type=STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']], - labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]] + labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]], + step_interval=X['step_interval'] ) total_parameter_count, trainable_parameter_count = self.count_parameters(X['network']) self.run_summary = RunSummary( diff --git a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py index bae664f5c..9b7b79ac8 100644 --- a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py +++ b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py @@ -14,6 +14,7 @@ from autoPyTorch.constants import REGRESSION_TASKS +from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit from autoPyTorch.pipeline.components.training.base_training import autoPyTorchTrainingComponent from autoPyTorch.pipeline.components.training.metrics.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score @@ -203,7 +204,8 @@ def prepare( metrics_during_training: bool, scheduler: _LRScheduler, task_type: int, - labels: Union[np.ndarray, torch.Tensor, pd.DataFrame] + labels: Union[np.ndarray, torch.Tensor, pd.DataFrame], + step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.batch ) -> None: # Save the device to be used @@ -233,6 +235,7 @@ def prepare( # Scheduler self.scheduler = scheduler + self.step_interval = step_interval # task type (used for calculating metrics) self.task_type = task_type @@ -254,11 +257,29 @@ def on_epoch_end(self, X: Dict[str, Any], epoch: int) -> bool: """ return False + def _scheduler_step( + self, + step_interval: StepIntervalUnit, + loss: Optional[float] = None + ) -> None: + + if self.step_interval != step_interval: + return + + if not self.scheduler: # skip if no scheduler defined + return + + try: + """ Some schedulers might use metrics """ + self.scheduler.step(metrics=loss) + except TypeError: + self.scheduler.step() + def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int, writer: Optional[SummaryWriter], ) -> Tuple[float, Dict[str, float]]: - ''' - Trains the model for a single epoch. + """ + Train the model for a single epoch. 
Args: train_loader (torch.utils.data.DataLoader): generator of features/label @@ -267,7 +288,7 @@ def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int, Returns: float: training loss Dict[str, float]: scores for each desired metric - ''' + """ loss_sum = 0.0 N = 0 @@ -296,6 +317,8 @@ def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int, epoch * len(train_loader) + step, ) + self._scheduler_step(step_interval=StepIntervalUnit.epoch, loss=loss_sum / N) + if self.metrics_during_training: return loss_sum / N, self.compute_metrics(outputs_data, targets_data) else: @@ -311,13 +334,13 @@ def cast_targets(self, targets: torch.Tensor) -> torch.Tensor: targets = targets.long().to(self.device) return targets - def train_step(self, data: np.ndarray, targets: np.ndarray) -> Tuple[float, torch.Tensor]: + def train_step(self, data: torch.Tensor, targets: torch.Tensor) -> Tuple[float, torch.Tensor]: """ Allows to train 1 step of gradient descent, given a batch of train/labels Args: - data (np.ndarray): input features to the network - targets (np.ndarray): ground truth to calculate loss + data (torch.Tensor): input features to the network + targets (torch.Tensor): ground truth to calculate loss Returns: torch.Tensor: The predictions of the network @@ -336,19 +359,15 @@ def train_step(self, data: np.ndarray, targets: np.ndarray) -> Tuple[float, torc loss = loss_func(self.criterion, outputs) loss.backward() self.optimizer.step() - if self.scheduler: - if 'ReduceLROnPlateau' in self.scheduler.__class__.__name__: - self.scheduler.step(loss) - else: - self.scheduler.step() + self._scheduler_step(step_interval=StepIntervalUnit.batch, loss=loss.item()) return loss.item(), outputs def evaluate(self, test_loader: torch.utils.data.DataLoader, epoch: int, writer: Optional[SummaryWriter], ) -> Tuple[float, Dict[str, float]]: - ''' - Evaluates the model in both metrics and criterion + """ + Evaluate the model in both metrics and criterion Args: test_loader (torch.utils.data.DataLoader): generator of features/label @@ -357,7 +376,7 @@ def evaluate(self, test_loader: torch.utils.data.DataLoader, epoch: int, Returns: float: test loss Dict[str, float]: scores for each desired metric - ''' + """ self.model.eval() loss_sum = 0.0 @@ -388,10 +407,12 @@ def evaluate(self, test_loader: torch.utils.data.DataLoader, epoch: int, epoch * len(test_loader) + step, ) + self._scheduler_step(step_interval=StepIntervalUnit.valid, loss=loss_sum / N) + self.model.train() return loss_sum / N, self.compute_metrics(outputs_data, targets_data) - def compute_metrics(self, outputs_data: np.ndarray, targets_data: np.ndarray + def compute_metrics(self, outputs_data: List[torch.Tensor], targets_data: List[torch.Tensor] ) -> Dict[str, float]: # TODO: change once Ravin Provides the PR outputs_data = torch.cat(outputs_data, dim=0).numpy() @@ -399,7 +420,7 @@ def compute_metrics(self, outputs_data: np.ndarray, targets_data: np.ndarray return calculate_score(targets_data, outputs_data, self.task_type, self.metrics) def get_class_weights(self, criterion: Type[torch.nn.Module], labels: Union[np.ndarray, torch.Tensor, pd.DataFrame] - ) -> Dict[str, np.ndarray]: + ) -> Dict[str, torch.Tensor]: strategy = get_loss_weight_strategy(criterion) weights = strategy(y=labels) weights = torch.from_numpy(weights) @@ -409,8 +430,8 @@ def get_class_weights(self, criterion: Type[torch.nn.Module], labels: Union[np.n else: return {'weight': weights} - def data_preparation(self, X: np.ndarray, y: np.ndarray, - ) -> 
Tuple[np.ndarray, Dict[str, np.ndarray]]: + def data_preparation(self, X: torch.Tensor, y: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]: """ Depending on the trainer choice, data fed to the network might be pre-processed on a different way. That is, in standard training we provide the data to the @@ -418,16 +439,17 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray, alter the data. Args: - X (np.ndarray): The batch training features - y (np.ndarray): The batch training labels + X (torch.Tensor): The batch training features + y (torch.Tensor): The batch training labels Returns: - np.ndarray: that processes data + torch.Tensor: that processes data Dict[str, np.ndarray]: arguments to the criterion function + TODO: Fix this typing. It is not np.ndarray. """ - raise NotImplementedError() + raise NotImplementedError - def criterion_preparation(self, y_a: np.ndarray, y_b: np.ndarray = None, lam: float = 1.0 + def criterion_preparation(self, y_a: torch.Tensor, y_b: torch.Tensor = None, lam: float = 1.0 ) -> Callable: # type: ignore """ Depending on the trainer choice, the criterion is not directly applied to the @@ -439,6 +461,6 @@ def criterion_preparation(self, y_a: np.ndarray, y_b: np.ndarray = None, lam: fl criterion calculation Returns: - Callable: a lambda that contains the new criterion calculation recipe + Callable: a lambda function that contains the new criterion calculation recipe """ - raise NotImplementedError() + raise NotImplementedError diff --git a/requirements.txt b/requirements.txt index c79104461..a2f23958f 100755 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ torchvision tensorboard scikit-learn>=0.24.0,<0.25.0 numpy -scipy +scipy>=0.14.1,<1.7.0 lockfile imgaug>=0.4.0 ConfigSpace>=0.4.14,<0.5 diff --git a/test/test_pipeline/components/setup/test_setup.py b/test/test_pipeline/components/setup/test_setup.py index 658923512..5d65ac14a 100644 --- a/test/test_pipeline/components/setup/test_setup.py +++ b/test/test_pipeline/components/setup/test_setup.py @@ -19,7 +19,11 @@ from autoPyTorch.pipeline.components.base_component import ThirdPartyComponents from autoPyTorch.pipeline.components.setup.lr_scheduler import ( BaseLRComponent, - SchedulerChoice + SchedulerChoice, +) +from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import ( + StepIntervalUnit, + StepIntervalUnitChoices ) from autoPyTorch.pipeline.components.setup.network_backbone import NetworkBackboneChoice from autoPyTorch.pipeline.components.setup.network_backbone.ResNetBackbone import ResBlock @@ -43,14 +47,15 @@ class DummyLR(BaseLRComponent): - def __init__(self, random_state=None): - pass + def __init__(self, step_interval: StepIntervalUnit, random_state=None): + super().__init__(step_interval=step_interval) @staticmethod def get_hyperparameter_search_space(dataset_properties=None): cs = ConfigurationSpace() return cs + @staticmethod def get_properties(dataset_properties=None): return { 'shortname': 'Dummy', @@ -67,6 +72,7 @@ def get_hyperparameter_search_space(dataset_properties=None): cs = ConfigurationSpace() return cs + @staticmethod def get_properties(dataset_properties=None): return { 'shortname': 'Dummy', @@ -83,6 +89,7 @@ def get_hyperparameter_search_space(dataset_properties=None): cs = ConfigurationSpace() return cs + @staticmethod def get_properties(dataset_properties=None): return { 'shortname': 'Dummy', @@ -149,7 +156,7 @@ def test_every_scheduler_is_valid(self): estimator_clone_params = estimator_clone.get_params() # Make sure all keys are copied 
properly - for k, v in estimator.get_params().items(): + for k in estimator.get_params().keys(): assert k in estimator_clone_params # Make sure the params getter of estimator are honored @@ -179,7 +186,7 @@ def test_get_set_config_space(self): # Whereas just one iteration will make sure the algorithm works, # doing five iterations increase the confidence. We will be able to # catch component specific crashes - for i in range(5): + for _ in range(5): config = cs.sample_configuration() config_dict = copy.deepcopy(config.get_dictionary()) scheduler_choice.set_hyperparameters(config) @@ -207,6 +214,21 @@ def test_scheduler_add(self): cs = SchedulerChoice(dataset_properties={}).get_hyperparameter_search_space() assert 'DummyLR' in str(cs) + def test_schduler_init(self): + for step_interval in StepIntervalUnitChoices: + DummyLR(step_interval=step_interval) + + for step_interval in ['Batch', 'foo']: + try: + DummyLR(step_interval=step_interval) + except ValueError: + pass + except Exception as e: + pytest.fail("The initialization of lr_scheduler raised an unexpected exception {}.".format(e)) + else: + pytest.fail("The initialization of lr_scheduler did not raise an Error " + "although the step_unit is invalid.") + class OptimizerTest: def test_every_optimizer_is_valid(self): @@ -233,7 +255,7 @@ def test_every_optimizer_is_valid(self): estimator_clone_params = estimator_clone.get_params() # Make sure all keys are copied properly - for k, v in estimator.get_params().items(): + for k in estimator.get_params().keys(): assert k in estimator_clone_params # Make sure the params getter of estimator are honored @@ -263,7 +285,7 @@ def test_get_set_config_space(self): # Whereas just one iteration will make sure the algorithm works, # doing five iterations increase the confidence. We will be able to # catch component specific crashes - for i in range(5): + for _ in range(5): config = cs.sample_configuration() config_dict = copy.deepcopy(config.get_dictionary()) optimizer_choice.set_hyperparameters(config) @@ -333,7 +355,7 @@ def test_dummy_forward_backward_pass(self, task_type_input_shape): cs = network_backbone_choice.get_hyperparameter_search_space(dataset_properties=dataset_properties) # test 10 random configurations - for i in range(10): + for _ in range(10): config = cs.sample_configuration() network_backbone_choice.set_hyperparameters(config) backbone = network_backbone_choice.choice.build_backbone(input_shape=input_shape) @@ -357,7 +379,7 @@ def test_every_backbone_is_valid(self): estimator_clone_params = estimator_clone.get_params() # Make sure all keys are copied properly - for k, v in estimator.get_params().items(): + for k in estimator.get_params().keys(): assert k in estimator_clone_params # Make sure the params getter of estimator are honored @@ -386,7 +408,7 @@ def test_get_set_config_space(self): # Whereas just one iteration will make sure the algorithm works, # doing five iterations increase the confidence. 
We will be able to # catch component specific crashes - for i in range(5): + for _ in range(5): config = cs.sample_configuration() config_dict = copy.deepcopy(config.get_dictionary()) network_backbone_choice.set_hyperparameters(config) @@ -500,7 +522,7 @@ def test_dummy_forward_backward_pass(self, task_type_input_output_shape): cs = network_head_choice.get_hyperparameter_search_space(dataset_properties=dataset_properties) # test 10 random configurations - for i in range(10): + for _ in range(10): config = cs.sample_configuration() network_head_choice.set_hyperparameters(config) head = network_head_choice.choice.build_head(input_shape=input_shape, @@ -534,7 +556,7 @@ def test_every_head_is_valid(self): estimator_clone_params = estimator_clone.get_params() # Make sure all keys are copied properly - for k, v in estimator.get_params().items(): + for k in estimator.get_params().keys(): assert k in estimator_clone_params # Make sure the params getter of estimator are honored @@ -563,7 +585,7 @@ def test_get_set_config_space(self): # Whereas just one iteration will make sure the algorithm works, # doing five iterations increase the confidence. We will be able to # catch component specific crashes - for i in range(5): + for _ in range(5): config = cs.sample_configuration() config_dict = copy.deepcopy(config.get_dictionary()) network_head_choice.set_hyperparameters(config) @@ -626,7 +648,7 @@ def test_every_network_initializer_is_valid(self): estimator_clone_params = estimator_clone.get_params() # Make sure all keys are copied properly - for k, v in estimator.get_params().items(): + for k in estimator.get_params().keys(): assert k in estimator_clone_params # Make sure the params getter of estimator are honored @@ -656,7 +678,7 @@ def test_get_set_config_space(self): # Whereas just one iteration will make sure the algorithm works, # doing five iterations increase the confidence. 
We will be able to # catch component specific crashes - for i in range(5): + for _ in range(5): config = cs.sample_configuration() config_dict = copy.deepcopy(config.get_dictionary()) network_initializer_choice.set_hyperparameters(config) diff --git a/test/test_pipeline/components/training/test_training.py b/test/test_pipeline/components/training/test_training.py index c55bd967c..3c0836d1f 100644 --- a/test/test_pipeline/components/training/test_training.py +++ b/test/test_pipeline/components/training/test_training.py @@ -26,7 +26,10 @@ StandardTrainer ) from autoPyTorch.pipeline.components.training.trainer.base_trainer import ( - BaseTrainerComponent, ) + BaseTrainerComponent, + BudgetTracker, + StepIntervalUnit +) sys.path.append(os.path.dirname(__file__)) from test.test_pipeline.components.training.base import BaseTraining # noqa (E402: module level import not at top of file) @@ -36,7 +39,7 @@ N_SAMPLES = 500 -class BaseDataLoaderTest(unittest.TestCase): +class TestBaseDataLoader(unittest.TestCase): def test_get_set_config_space(self): """ Makes sure that the configuration space of the base data loader @@ -49,7 +52,7 @@ def test_get_set_config_space(self): self.assertEqual(cs.get_hyperparameter('batch_size').default_value, 64) # Make sure we can properly set some random configs - for i in range(5): + for _ in range(5): config = cs.sample_configuration() config_dict = copy.deepcopy(config.get_dictionary()) loader.set_hyperparameters(config) @@ -140,9 +143,9 @@ def test_evaluate(self): loader, criterion, epochs, - logger) = self.prepare_trainer(N_SAMPLES, - BaseTrainerComponent(), - constants.TABULAR_CLASSIFICATION) + _) = self.prepare_trainer(N_SAMPLES, + BaseTrainerComponent(), + constants.TABULAR_CLASSIFICATION) prev_loss, prev_metrics = trainer.evaluate(loader, epoch=1, writer=None) assert 'accuracy' in prev_metrics @@ -161,8 +164,77 @@ def test_evaluate(self): assert metrics['accuracy'] > prev_metrics['accuracy'] assert metrics['accuracy'] > 0.5 + def test_scheduler_step(self): + trainer = BaseTrainerComponent() + model = torch.nn.Linear(1, 1) + + base_lr, factor = 1, 10 + optimizer = torch.optim.SGD(model.parameters(), lr=base_lr) + trainer.scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=list(range(1, 5)), + gamma=factor + ) -class StandardTrainerTest(BaseTraining): + target_lr = base_lr + for trainer_step_interval in StepIntervalUnit: + trainer.step_interval = trainer_step_interval + for step_interval in StepIntervalUnit: + if step_interval == trainer_step_interval: + target_lr *= factor + + trainer._scheduler_step(step_interval=step_interval) + lr = optimizer.param_groups[0]['lr'] + assert target_lr - 1e-6 <= lr <= target_lr + 1e-6 + + def test_train_step(self): + device = torch.device('cpu') + model = torch.nn.Linear(1, 1).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=1) + data, targets = torch.Tensor([1.]).to(device), torch.Tensor([1.]).to(device) + ms = [3, 5, 6] + params = { + 'metrics': [], + 'device': device, + 'task_type': constants.TABULAR_REGRESSION, + 'labels': torch.Tensor([]), + 'metrics_during_training': False, + 'budget_tracker': BudgetTracker(budget_type=''), + 'criterion': torch.nn.MSELoss, + 'optimizer': optimizer, + 'scheduler': torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=ms, gamma=2), + 'model': model, + 'step_interval': StepIntervalUnit.epoch + } + trainer = StandardTrainer() + trainer.prepare(**params) + + for _ in range(10): + trainer.train_step( + data=data, + targets=targets + ) + lr = 
optimizer.param_groups[0]['lr'] + assert lr == 1 + + params.update(step_interval=StepIntervalUnit.batch) + trainer = StandardTrainer() + trainer.prepare(**params) + + target_lr = 1 + for i in range(10): + trainer.train_step( + data=data, + targets=targets + ) + if i + 1 in ms: + target_lr *= 2 + + lr = optimizer.param_groups[0]['lr'] + assert lr == target_lr + + +class TestStandardTrainer(BaseTraining): def test_regression_epoch_training(self, n_samples): (trainer, _, @@ -170,16 +242,16 @@ def test_regression_epoch_training(self, n_samples): loader, _, epochs, - logger) = self.prepare_trainer(n_samples, - StandardTrainer(), - constants.TABULAR_REGRESSION, - OVERFIT_EPOCHS) + _) = self.prepare_trainer(n_samples, + StandardTrainer(), + constants.TABULAR_REGRESSION, + OVERFIT_EPOCHS) # Train the model counter = 0 r2 = 0 while r2 < 0.7: - loss, metrics = trainer.train_epoch(loader, epoch=1, writer=None) + _, metrics = trainer.train_epoch(loader, epoch=1, writer=None) counter += 1 r2 = metrics['r2'] @@ -193,16 +265,16 @@ def test_classification_epoch_training(self, n_samples): loader, _, epochs, - logger) = self.prepare_trainer(n_samples, - StandardTrainer(), - constants.TABULAR_CLASSIFICATION, - OVERFIT_EPOCHS) + _) = self.prepare_trainer(n_samples, + StandardTrainer(), + constants.TABULAR_CLASSIFICATION, + OVERFIT_EPOCHS) # Train the model counter = 0 accuracy = 0 while accuracy < 0.7: - loss, metrics = trainer.train_epoch(loader, epoch=1, writer=None) + _, metrics = trainer.train_epoch(loader, epoch=1, writer=None) counter += 1 accuracy = metrics['accuracy'] @@ -210,7 +282,7 @@ def test_classification_epoch_training(self, n_samples): pytest.fail(f"Could not overfit a dummy classification under {epochs} epochs") -class MixUpTrainerTest(BaseTraining): +class TestMixUpTrainer(BaseTraining): def test_classification_epoch_training(self, n_samples): (trainer, _, @@ -218,16 +290,16 @@ def test_classification_epoch_training(self, n_samples): loader, _, epochs, - logger) = self.prepare_trainer(n_samples, - MixUpTrainer(alpha=0.5), - constants.TABULAR_CLASSIFICATION, - OVERFIT_EPOCHS) + _) = self.prepare_trainer(n_samples, + MixUpTrainer(alpha=0.5), + constants.TABULAR_CLASSIFICATION, + OVERFIT_EPOCHS) # Train the model counter = 0 accuracy = 0 while accuracy < 0.7: - loss, metrics = trainer.train_epoch(loader, epoch=1, writer=None) + _, metrics = trainer.train_epoch(loader, epoch=1, writer=None) counter += 1 accuracy = metrics['accuracy'] @@ -235,7 +307,7 @@ def test_classification_epoch_training(self, n_samples): pytest.fail(f"Could not overfit a dummy classification under {epochs} epochs") -class TrainerTest(unittest.TestCase): +class TestTrainer(unittest.TestCase): def test_every_trainer_is_valid(self): """ Makes sure that every trainer is a valid estimator. @@ -260,7 +332,7 @@ def test_every_trainer_is_valid(self): estimator_clone_params = estimator_clone.get_params() # Make sure all keys are copied properly - for k, v in estimator.get_params().items(): + for k in estimator.get_params().keys(): self.assertIn(k, estimator_clone_params) # Make sure the params getter of estimator are honored @@ -292,7 +364,7 @@ def test_get_set_config_space(self): # Whereas just one iteration will make sure the algorithm works, # doing five iterations increase the confidence. We will be able to # catch component specific crashes - for i in range(5): + for _ in range(5): config = cs.sample_configuration() config_dict = copy.deepcopy(config.get_dictionary()) trainer_choice.set_hyperparameters(config)
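
Note for reviewers: the following standalone sketch illustrates the stepping behaviour this patch introduces. It mirrors the StepIntervalUnit enum from lr_scheduler/constants.py and the dispatch logic of the new BaseTrainerComponent._scheduler_step, but it uses only plain PyTorch objects; the helper name scheduler_step and the toy training loop are illustrative stand-ins, not code from the patch.

from enum import Enum

import torch


class StepIntervalUnit(Enum):
    # Mirrors the enum added in lr_scheduler/constants.py
    batch = 'batch'
    epoch = 'epoch'
    valid = 'valid'


def scheduler_step(scheduler, call_interval, trainer_interval, loss=None):
    # Step the scheduler only when the interval of this call site matches the
    # interval configured on the trainer (cf. BaseTrainerComponent._scheduler_step).
    if scheduler is None or call_interval != trainer_interval:
        return
    try:
        scheduler.step(metrics=loss)  # schedulers such as ReduceLROnPlateau expect a metric
    except TypeError:
        scheduler.step()              # ordinary schedulers take no arguments


model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
trainer_interval = StepIntervalUnit.epoch  # the default most scheduler components get in this patch

for epoch in range(3):
    for batch in range(4):
        # forward / backward / optimizer.step() would happen here
        scheduler_step(scheduler, StepIntervalUnit.batch, trainer_interval)  # no-op: intervals differ
    scheduler_step(scheduler, StepIntervalUnit.epoch, trainer_interval, loss=0.0)
    print(epoch, optimizer.param_groups[0]['lr'])  # lr halves once per epoch

With trainer_interval set to StepIntervalUnit.batch instead, the call inside the inner loop would be the one that advances the scheduler, which is the behaviour the new test_scheduler_step and test_train_step tests exercise with a MultiStepLR.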