
Commit

[feat] Add flexible step-wise LR scheduler with minimum changes (automl#256)

* [doc] Add the bug fix information for ipython users

* [fix] Delete the gcc and gnn install steps because I confirmed that we can build without them

* [feat] Add flexible step-wise LR scheduler with minimum changes

Since we would like to merge this feature promptly, I cut this new branch
from the branch hot-fix-adapt... and narrowed down the scope of this PR.
The subsequent PR addresses the remaining issues, especially the formatting
and mypy typing.

* [fix] Fix a flake8 issue and remove unneeded changes caused by mis-branching

* [test] Add the tests for the new features

* [fix] Fix flake8 issues and change torch.tensor to torch.Tensor in the typing checks

The intention behind the change from torch.tensor to torch.Tensor is that
`import torch.tensor` raised a ModuleNotFoundError.
Furthermore, torch.tensor is not a TensorType class, but torch.Tensor is.
Therefore, I changed torch.tensor to torch.Tensor.

* [feat] Make it possible to add the step_unit to ConfigSpace

* [fix] Fix pytest issues by adding batch-wise update to incumbent

Since the previous version always used batch-wise updates, I added
step_unit = batch, which avoids the errors I got from pytest.

* [fix] Add step_unit info in the greedy portfolio json file

Since the latest version only supports batch-wise updates,
I inserted step_unit = "batch" so that greedy portfolio
selection can run.

* [refactor] Rebase onto the latest development branch and add overridden functions in base_scheduler

* [fix] Fix flake8 and mypy issues

* [fix] Fix flake8 issues

* [test] Add the test for the train step and the lr scheduler check

* [refactor] Change the name  to

* [fix] Fix flake8 issues

* [fix] Disable the step_interval option from the hyperparameter settings

* [fix] Change the default step_interval to Epoch-wise

* [fix] Fix the step timing for ReduceLROnPlateau and fix flake8 issues

* [fix] Add the after-validation option to StepIntervalUnit for ReduceLROnPlateau

* [fix] Fix flake8 issues

* [fix] Fix loss value for epoch-wise scheduler update

* [fix] Delete type check of step_interval and make it property

Since step_interval should not be modified from outside,
I made it a property of the base_scheduler class.
Furthermore, since we only need to check the type of step_interval
at initialization, I deleted the type check from the prepare method.

* [fix] Fix a mypy issue

* [fix] Fix a mypy issue

* [fix] Fix mypy issues

* [fix] Fix mypy issues

* [feedback] Address Ravin's suggestions
nabenabe0928 authored Jun 23, 2021
1 parent 999f3c3 commit 76bdde7
Showing 17 changed files with 287 additions and 115 deletions.
@@ -8,10 +8,10 @@
import numpy as np

import torch.optim.lr_scheduler
from torch.optim.lr_scheduler import _LRScheduler

from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent
from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit
from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter


@@ -26,13 +26,13 @@ class CosineAnnealingLR(BaseLRComponent):
def __init__(
self,
T_max: int,
step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch,
random_state: Optional[np.random.RandomState] = None
):

super().__init__()
super().__init__(step_interval)
self.T_max = T_max
self.random_state = random_state
self.scheduler = None # type: Optional[_LRScheduler]

def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent:
"""
@@ -71,6 +71,8 @@ def get_hyperparameter_search_space(
default_value=200,
)
) -> ConfigurationSpace:

cs = ConfigurationSpace()
add_hyperparameter(cs, T_max, UniformIntegerHyperparameter)

return cs
@@ -1,15 +1,18 @@
from typing import Any, Dict, Optional, Union

from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import UniformFloatHyperparameter, UniformIntegerHyperparameter
from ConfigSpace.hyperparameters import (
UniformFloatHyperparameter,
UniformIntegerHyperparameter
)

import numpy as np

import torch.optim.lr_scheduler
from torch.optim.lr_scheduler import _LRScheduler

from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent
from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit
from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter


@@ -30,13 +33,13 @@ def __init__(
self,
T_0: int,
T_mult: int,
random_state: Optional[np.random.RandomState] = None
step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch,
random_state: Optional[np.random.RandomState] = None,
):
super().__init__()
super().__init__(step_interval)
self.T_0 = T_0
self.T_mult = T_mult
self.random_state = random_state
self.scheduler = None # type: Optional[_LRScheduler]

def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent:
"""
@@ -78,8 +81,9 @@ def get_hyperparameter_search_space(
T_mult: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter='T_mult',
value_range=(1.0, 2.0),
default_value=1.0,
),
)
) -> ConfigurationSpace:

cs = ConfigurationSpace()
add_hyperparameter(cs, T_0, UniformIntegerHyperparameter)
add_hyperparameter(cs, T_mult, UniformFloatHyperparameter)
@@ -10,10 +10,10 @@
import numpy as np

import torch.optim.lr_scheduler
from torch.optim.lr_scheduler import _LRScheduler

from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent
from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit
from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter


@@ -39,16 +39,16 @@ def __init__(
base_lr: float,
mode: str,
step_size_up: int,
step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch,
max_lr: float = 0.1,
random_state: Optional[np.random.RandomState] = None
):
super().__init__()
super().__init__(step_interval)
self.base_lr = base_lr
self.mode = mode
self.max_lr = max_lr
self.step_size_up = step_size_up
self.random_state = random_state
self.scheduler = None # type: Optional[_LRScheduler]

def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent:
"""
@@ -8,10 +8,10 @@
import numpy as np

import torch.optim.lr_scheduler
from torch.optim.lr_scheduler import _LRScheduler

from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent
from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit
from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter


@@ -27,13 +27,13 @@ class ExponentialLR(BaseLRComponent):
def __init__(
self,
gamma: float,
step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch,
random_state: Optional[np.random.RandomState] = None
):

super().__init__()
super().__init__(step_interval)
self.gamma = gamma
self.random_state = random_state
self.scheduler = None # type: Optional[_LRScheduler]

def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent:
"""
@@ -72,6 +72,8 @@ def get_hyperparameter_search_space(
default_value=0.9,
)
) -> ConfigurationSpace:

cs = ConfigurationSpace()
add_hyperparameter(cs, gamma, UniformFloatHyperparameter)

return cs
@@ -4,10 +4,9 @@

import numpy as np

from torch.optim.lr_scheduler import _LRScheduler

from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent
from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit


class NoScheduler(BaseLRComponent):
@@ -17,12 +16,12 @@ class NoScheduler(BaseLRComponent):
"""
def __init__(
self,
step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch,
random_state: Optional[np.random.RandomState] = None
):

super().__init__()
super().__init__(step_interval)
self.random_state = random_state
self.scheduler = None # type: Optional[_LRScheduler]

def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent:
"""
@@ -10,10 +10,10 @@
import numpy as np

import torch.optim.lr_scheduler
from torch.optim.lr_scheduler import _LRScheduler

from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent
from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit
from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter


@@ -31,22 +31,26 @@ class ReduceLROnPlateau(BaseLRComponent):
factor (float): Factor by which the learning rate will be reduced. new_lr = lr * factor.
patience (int): Number of epochs with no improvement after which learning
rate will be reduced.
step_interval (str): step should be called after validation in the case of ReduceLROnPlateau
random_state (Optional[np.random.RandomState]): random state
Reference:
https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html#torch.optim.lr_scheduler.ReduceLROnPlateau
"""

def __init__(
self,
mode: str,
factor: float,
patience: int,
random_state: Optional[np.random.RandomState] = None
step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.valid,
random_state: Optional[np.random.RandomState] = None,
):
super().__init__()
super().__init__(step_interval)
self.mode = mode
self.factor = factor
self.patience = patience
self.random_state = random_state
self.scheduler = None # type: Optional[_LRScheduler]

def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent:
"""
@@ -93,7 +97,7 @@ def get_hyperparameter_search_space(
factor: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter='factor',
value_range=(0.01, 0.9),
default_value=0.1,
),
)
) -> ConfigurationSpace:

cs = ConfigurationSpace()
8 changes: 4 additions & 4 deletions autoPyTorch/pipeline/components/setup/lr_scheduler/StepLR.py
@@ -9,10 +9,10 @@
import numpy as np

import torch.optim.lr_scheduler
from torch.optim.lr_scheduler import _LRScheduler

from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent
from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit
from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter


@@ -32,13 +32,13 @@ def __init__(
self,
step_size: int,
gamma: float,
step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch,
random_state: Optional[np.random.RandomState] = None
):
super().__init__()
super().__init__(step_interval)
self.gamma = gamma
self.step_size = step_size
self.random_state = random_state
self.scheduler = None # type: Optional[_LRScheduler]

def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent:
"""
@@ -80,7 +80,7 @@ def get_hyperparameter_search_space(
step_size: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter='step_size',
value_range=(1, 10),
default_value=5,
),
)
) -> ConfigurationSpace:
cs = ConfigurationSpace()

@@ -16,6 +16,7 @@
)
from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent


directory = os.path.split(__file__)[0]
_schedulers = find_components(__package__,
directory,
@@ -1,23 +1,39 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from autoPyTorch.pipeline.components.setup.base_setup import autoPyTorchSetupComponent
from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit, StepIntervalUnitChoices
from autoPyTorch.utils.common import FitRequirement


class BaseLRComponent(autoPyTorchSetupComponent):
"""Provide an abstract interface for schedulers
in Auto-Pytorch"""

def __init__(self) -> None:
def __init__(self, step_interval: Union[str, StepIntervalUnit]):
super().__init__()
self.scheduler = None # type: Optional[_LRScheduler]
self._step_interval: StepIntervalUnit

if isinstance(step_interval, str):
if step_interval not in StepIntervalUnitChoices:
raise ValueError('step_interval must be either {}, but got {}.'.format(
StepIntervalUnitChoices,
step_interval
))
self._step_interval = getattr(StepIntervalUnit, step_interval)
else:
self._step_interval = step_interval

self.add_fit_requirements([
FitRequirement('optimizer', (Optimizer,), user_defined=False, dataset_property=False)])

@property
def step_interval(self) -> StepIntervalUnit:
return self._step_interval

def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
"""
Adds the scheduler into the fit dictionary 'X' and returns it.
@@ -26,7 +42,11 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
Returns:
(Dict[str, Any]): the updated 'X' dictionary
"""
X.update({'lr_scheduler': self.scheduler})

X.update(
lr_scheduler=self.scheduler,
step_interval=self.step_interval
)
return X

def get_scheduler(self) -> _LRScheduler:
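To make the new constructor contract concrete, here is a minimal usage sketch. The validation behaviour follows the BaseLRComponent.__init__ shown above; the import path for CosineAnnealingLR is an assumption inferred from the StepLR.py location, and the snippet is illustrative rather than code from this commit.

# Hedged sketch: step_interval can be a string or a StepIntervalUnit member.
# The module path for CosineAnnealingLR is assumed to mirror StepLR.py.
from autoPyTorch.pipeline.components.setup.lr_scheduler.CosineAnnealingLR import CosineAnnealingLR
from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit

# A valid string name is converted to the corresponding enum member.
component = CosineAnnealingLR(T_max=200, step_interval='epoch')
assert component.step_interval is StepIntervalUnit.epoch

# Enum members are accepted as-is.
component = CosineAnnealingLR(T_max=200, step_interval=StepIntervalUnit.batch)

# Any name outside {'batch', 'epoch', 'valid'} raises ValueError at construction time.
try:
    CosineAnnealingLR(T_max=200, step_interval='iteration')
except ValueError as err:
    print(err)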
17 changes: 17 additions & 0 deletions autoPyTorch/pipeline/components/setup/lr_scheduler/constants.py
@@ -0,0 +1,17 @@
from enum import Enum


class StepIntervalUnit(Enum):
"""
By which interval we perform the step for learning rate schedulers.
Attributes:
batch (str): We update every batch evaluation
epoch (str): We update every epoch
valid (str): We update every validation
"""
batch = 'batch'
epoch = 'epoch'
valid = 'valid'


StepIntervalUnitChoices = [step_interval.name for step_interval in StepIntervalUnit]
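To illustrate how a trainer is expected to consume this enum, here is a hypothetical dispatch helper. Only the three interval names and the ReduceLROnPlateau timing described in the commit messages come from this commit; the helper name, its signature, and the call-site comments are assumptions, not autoPyTorch's actual trainer code.

from typing import Optional

import torch.optim.lr_scheduler as lr_scheduler

from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit


def maybe_step(scheduler, configured: StepIntervalUnit, now: StepIntervalUnit,
               val_metric: Optional[float] = None) -> None:
    # Step only when the configured interval matches the current hook in the loop.
    if scheduler is None or configured != now:
        return
    if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
        # ReduceLROnPlateau needs the monitored metric, hence its 'valid' default interval.
        scheduler.step(val_metric)
    else:
        scheduler.step()

# Hypothetical call sites inside a training loop:
#   after every batch:      maybe_step(sched, interval, StepIntervalUnit.batch)
#   after every epoch:      maybe_step(sched, interval, StepIntervalUnit.epoch)
#   after every validation: maybe_step(sched, interval, StepIntervalUnit.valid, val_metric=val_loss)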
19 changes: 11 additions & 8 deletions autoPyTorch/pipeline/components/training/trainer/MixUpTrainer.py
@@ -30,31 +30,34 @@ def __init__(self, alpha: float, weighted_loss: bool = False,
self.weighted_loss = weighted_loss
self.alpha = alpha

def data_preparation(self, X: np.ndarray, y: np.ndarray,
) -> typing.Tuple[np.ndarray, typing.Dict[str, np.ndarray]]:
def data_preparation(self, X: torch.Tensor, y: torch.Tensor,
) -> typing.Tuple[torch.Tensor, typing.Dict[str, np.ndarray]]:
"""
Depending on the trainer choice, data fed to the network might be pre-processed
on a different way. That is, in standard training we provide the data to the
network as we receive it to the loader. Some regularization techniques, like mixup
alter the data.
Args:
X (np.ndarray): The batch training features
y (np.ndarray): The batch training labels
X (torch.Tensor): The batch training features
y (torch.Tensor): The batch training labels
Returns:
np.ndarray: that processes data
torch.Tensor: that processes data
typing.Dict[str, np.ndarray]: arguments to the criterion function
TODO: Fix this typing. It is not np.ndarray.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

lam = self.random_state.beta(self.alpha, self.alpha) if self.alpha > 0. else 1.
batch_size = X.size()[0]
index = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)
batch_size = X.shape[0]
index = torch.randperm(batch_size).to(device)

mixed_x = lam * X + (1 - lam) * X[index, :]
y_a, y_b = y, y[index]
return mixed_x, {'y_a': y_a, 'y_b': y_b, 'lam': lam}

def criterion_preparation(self, y_a: np.ndarray, y_b: np.ndarray = None, lam: float = 1.0
def criterion_preparation(self, y_a: torch.Tensor, y_b: torch.Tensor = None, lam: float = 1.0
) -> typing.Callable:
return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

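For context, this is how the two MixUp hooks above combine inside a single training step. The hook signatures come from this diff; train_loader, network, criterion, optimizer, and the trainer instance are assumed to exist and are only illustrative.

# trainer: a MixUpTrainer instance; network, criterion, optimizer, train_loader are assumed.
X, y = next(iter(train_loader))                        # one mini-batch of tensors
mixed_x, crit_kwargs = trainer.data_preparation(X, y)  # mixes inputs, keeps both label sets

optimizer.zero_grad()
pred = network(mixed_x)

# criterion_preparation returns a callable blending the two losses with weight lam.
loss_fn = trainer.criterion_preparation(**crit_kwargs)
loss = loss_fn(criterion, pred)
loss.backward()
optimizer.step()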