Commit 5f07024 (1 parent: 2b139de)
Showing 7 changed files with 469 additions and 0 deletions.
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Shuhei Watanabe

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,111 @@
---
author: Shuhei Watanabe
title: "c-TPE: Tree-structured Parzen Estimator with Inequality Constraints for Expensive Hyperparameter Optimization"
description: "The optimizer that reproduces the algorithm described in the paper 'c-TPE: Tree-structured Parzen Estimator with Inequality Constraints for Expensive Hyperparameter Optimization'."
tags: [sampler, tpe, c-tpe, paper, research]
optuna_versions: [v4.1.0]
license: MIT License
---

## Abstract

This package reproduces the c-TPE algorithm from the following paper, published at IJCAI'23:

- [c-TPE: Tree-structured Parzen Estimator with Inequality Constraints for Expensive Hyperparameter Optimization](https://arxiv.org/abs/2211.14411)

The sampler's default parameters follow the setup recommended in the paper, and the paper's experiments can be reproduced with this sampler.

## Class or Function Names

- cTPESampler

## Installation

This package requires Optuna v4.0.0 or later.

```shell
# The only requirement is Optuna.
$ pip install optuna

# Alternatively, you can install the pinned requirements:
$ pip install -r https://hub.optuna.org/samplers/ctpe/requirements.txt
```

## Example

TODO: Change here.

This sampler supports the arguments discussed in [the original paper](https://arxiv.org/abs/2304.11127) and can be used as follows.

```python
import numpy as np

import optuna

import optunahub


def objective(trial):
    x = trial.suggest_float("x", -5, 5)
    y = trial.suggest_int("y", -5, 5)
    z = trial.suggest_categorical("z", ["a", "aa", "aaa"])
    return len(z) * (x**2 + y**2)


module = optunahub.load_module(package="samplers/tpe_tutorial")
optuna.logging.set_verbosity(optuna.logging.CRITICAL)
arg_choices = {
    "consider_prior": [True, False],
    "consider_magic_clip": [True, False],
    "multivariate": [True, False],
    "b_magic_exponent": [0.5, 1.0, 2.0, np.inf],
    "min_bandwidth_factor": [0.01, 0.1],
    "gamma_strategy": ["linear", "sqrt"],
    "weight_strategy": ["uniform", "old-decay", "old-drop", "EI"],
    "bandwidth_strategy": ["optuna", "hyperopt", "scott"],
    "categorical_prior_weight": [0.1, None],
}
for arg_name, choices in arg_choices.items():
    results = []
    for choice in choices:
        print(arg_name, choice)
        sampler = module.CustomizableTPESampler(seed=0, **{arg_name: choice})
        study = optuna.create_study(sampler=sampler)
        study.optimize(objective, n_trials=100 if arg_name != "b_magic_exponent" else 200)
        results.append(study.trials[-1].value)

    print(f"Did every setup yield different results for {arg_name}?: {len(set(results)) == len(results)}")
```
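
Since the example above is carried over from the `tpe_tutorial` package (note the TODO), here is a hedged sketch of what constrained use of `cTPESampler` might look like. It assumes `cTPESampler` accepts a `constraints_func` callable following Optuna's usual convention (values <= 0 mean feasible); the sampler's constructor is defined in a file not displayed above, so treat this as an illustration only.

```python
import optuna
import optunahub


def objective(trial):
    x = trial.suggest_float("x", -15, 30)
    y = trial.suggest_float("y", -15, 30)
    # Store the constraint value so the constraints callable can read it back.
    trial.set_user_attr("constraint", x + y - 10)
    return x**2 + y**2


def constraints(trial):
    # Returns a sequence of constraint values; each entry is feasible when <= 0.
    return [trial.user_attrs["constraint"]]


# Hypothetical wiring: constraints_func is assumed to mirror the interface of
# Optuna's built-in constrained samplers.
module = optunahub.load_module(package="samplers/ctpe")
sampler = module.cTPESampler(constraints_func=constraints)
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=50)
print(study.best_trial.params, study.best_trial.value)
```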

In the paper, the following arguments, which are not available in Optuna's built-in TPE, were investigated (a short illustration of the splitting strategies appears below):

- `gamma_strategy`: The splitting algorithm in Table 3. The choices are `linear` and `sqrt`.
- `gamma_beta`: The beta parameter for the splitting algorithm in Table 3. This value must be positive.
- `weight_strategy`: The weighting algorithm in Table 3. The choices are `uniform`, `old-decay`, `old-drop`, and `EI`.
- `categorical_prior_weight`: The categorical bandwidth in Table 3. If `None`, the Optuna default algorithm is used.
- `bandwidth_strategy`: The bandwidth selection heuristic in Table 6. The choices are `optuna`, `hyperopt`, and `scott`.
- `min_bandwidth_factor`: The minimum bandwidth factor in Table 6. This value must be positive.
- `b_magic_exponent`: The exponent alpha in Table 6. Optuna uses 1.0 by default.

For more details, please check [the paper](https://arxiv.org/abs/2304.11127).
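
As a purely illustrative sketch of how the two splitting strategies behave, the snippet below computes the number of trials treated as "good" for a hypothetical `gamma_beta` (the cap of 25 matches the implementation further down in this commit):

```python
import numpy as np

beta = 0.15  # hypothetical value, not the paper's recommended setting
for n_trials in [10, 50, 200]:
    # "linear" grows proportionally to the trial count; "sqrt" grows much more slowly.
    linear = min(int(np.ceil(beta * n_trials)), 25)
    sqrt = min(int(np.ceil(beta * np.sqrt(n_trials))), 25)
    print(f"n={n_trials:>3}: linear -> {linear:>2} good trials, sqrt -> {sqrt} good trials")
```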

### BibTeX

When you use this sampler, please cite the following:

```bibtex
@inproceedings{watanabe_ctpe_ijcai_2023,
  title={{c-TPE}: Tree-structured {P}arzen Estimator with Inequality Constraints for Expensive Hyperparameter Optimization},
  author={Watanabe, Shuhei and Hutter, Frank},
  booktitle={International Joint Conference on Artificial Intelligence},
  year={2023}
}
@inproceedings{watanabe_ctpe_workshop_2022,
  title={{c-TPE}: Generalizing Tree-structured {P}arzen Estimator with Inequality Constraints for Continuous and Categorical Hyperparameter Optimization},
  author={Watanabe, Shuhei and Hutter, Frank},
  booktitle={Gaussian Processes, Spatiotemporal Modeling, and Decision-making Systems Workshop at Advances in Neural Information Processing Systems},
  year={2022}
}
```
@@ -0,0 +1,4 @@
from .sampler import cTPESampler


__all__ = ["cTPESampler"]
@@ -0,0 +1,48 @@
from __future__ import annotations

import numpy as np


class GammaFunc:
    """Maps the number of observed trials to the number of "good" trials (Table 3)."""

    def __init__(self, strategy: str, beta: float):
        strategy_choices = ["linear", "sqrt"]
        if strategy not in strategy_choices:
            raise ValueError(f"strategy must be in {strategy_choices}, but got {strategy}.")

        self._strategy = strategy
        self._beta = beta

    def __call__(self, x: int) -> int:
        if self._strategy == "linear":
            n = int(np.ceil(self._beta * x))
        elif self._strategy == "sqrt":
            n = int(np.ceil(self._beta * np.sqrt(x)))
        else:
            assert False, "Should not reach."

        return min(n, 25)


class WeightFunc:
    """Assigns a weight to each of the x observed trials, oldest first (Table 3)."""

    def __init__(self, strategy: str):
        strategy_choices = ["old-decay", "old-drop", "uniform"]
        if strategy not in strategy_choices:
            raise ValueError(f"strategy must be in {strategy_choices}, but got {strategy}.")

        self._strategy = strategy

    def __call__(self, x: int) -> np.ndarray:
        if x == 0:
            return np.asarray([])
        elif x < 25 or self._strategy == "uniform":
            return np.ones(x)
        elif self._strategy == "old-decay":
            # Older trials get linearly decaying weights; the newest 25 keep weight 1.
            ramp = np.linspace(1.0 / x, 1.0, num=x - 25)
            flat = np.ones(25)
            return np.concatenate([ramp, flat], axis=0)
        elif self._strategy == "old-drop":
            # All but the newest 25 trials are effectively dropped.
            weights = np.ones(x)
            weights[:-25] = 1e-12
            return weights
        else:
            assert False, "Should not reach."
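
For reference, a short illustrative snippet exercising `WeightFunc` as defined above (the trial count is arbitrary, not taken from the paper):

```python
# Illustrative only: requires WeightFunc from the file above.
decay = WeightFunc(strategy="old-decay")(30)
drop = WeightFunc(strategy="old-drop")(30)
print(decay[:3], decay[-3:])  # oldest trials ramp up from ~1/30 toward 1.0
print(drop[:3], drop[-3:])    # all but the newest 25 trials get ~zero weight
```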
@@ -0,0 +1,162 @@
from __future__ import annotations

from collections.abc import Callable
from typing import NamedTuple

import numpy as np
from optuna.distributions import CategoricalDistribution
from optuna.samplers._tpe.parzen_estimator import _ParzenEstimator
from optuna.samplers._tpe.probability_distributions import _BatchedCategoricalDistributions
from optuna.samplers._tpe.probability_distributions import _BatchedDiscreteTruncNormDistributions
from optuna.samplers._tpe.probability_distributions import _BatchedDistributions
from optuna.samplers._tpe.probability_distributions import _BatchedTruncNormDistributions


class _CustomizableParzenEstimatorParameters(NamedTuple):
    consider_prior: bool
    prior_weight: float | None
    consider_magic_clip: bool
    weights: Callable[[int], np.ndarray]
    multivariate: bool
    b_magic_exponent: float
    min_bandwidth_factor: float
    bandwidth_strategy: str
    categorical_prior_weight: float | None


def _bandwidth_hyperopt(
    mus: np.ndarray,
    low: float,
    high: float,
    step: float | None,
) -> np.ndarray:
    # Hyperopt-style bandwidth: each kernel's sigma is the larger of the distances
    # to its neighboring kernels, with the domain endpoints acting as sentinels.
    step_or_0 = step or 0
    sorted_indices = np.argsort(mus)
    sorted_mus_with_endpoints = np.empty(len(mus) + 2, dtype=float)
    sorted_mus_with_endpoints[0] = low - step_or_0 / 2
    sorted_mus_with_endpoints[1:-1] = mus[sorted_indices]
    sorted_mus_with_endpoints[-1] = high + step_or_0 / 2
    sorted_sigmas = np.maximum(
        sorted_mus_with_endpoints[1:-1] - sorted_mus_with_endpoints[0:-2],
        sorted_mus_with_endpoints[2:] - sorted_mus_with_endpoints[1:-1],
    )
    return sorted_sigmas[np.argsort(sorted_indices)]


def _bandwidth_optuna(
    n_observations: int,
    consider_prior: bool,
    domain_range: float,
    dim: int,
) -> np.ndarray:
    # Optuna's default heuristic: a single bandwidth shared by all kernels.
    SIGMA0_MAGNITUDE = 0.2
    sigma = SIGMA0_MAGNITUDE * max(n_observations, 1) ** (-1.0 / (dim + 4)) * domain_range
    return np.full(shape=(n_observations + consider_prior,), fill_value=sigma)


def _bandwidth_scott(mus: np.ndarray) -> np.ndarray:
    # Scott's rule based on the standard deviation and the interquartile range.
    std = np.std(mus, ddof=int(mus.size > 1))
    IQR = np.subtract.reduce(np.percentile(mus, [75, 25]))
    return np.full_like(mus, 1.059 * min(IQR / 1.34, std) * mus.size**-0.2)


def _clip_bandwidth(
    sigmas: np.ndarray,
    n_observations: int,
    domain_range: float,
    consider_prior: bool,
    consider_magic_clip: bool,
    b_magic_exponent: float,
    min_bandwidth_factor: float,
) -> np.ndarray:
    # We adjust the range of the `sigmas` according to the `consider_magic_clip` flag.
    maxsigma = 1.0 * domain_range
    if consider_magic_clip:
        bandwidth_factor = max(
            min_bandwidth_factor, 1.0 / (1 + n_observations + consider_prior) ** b_magic_exponent
        )
        minsigma = bandwidth_factor * domain_range
    else:
        minsigma = 1e-12

    clipped_sigmas = np.asarray(np.clip(sigmas, minsigma, maxsigma))
    if consider_prior:
        clipped_sigmas[-1] = maxsigma

    return clipped_sigmas


class _CustomizableParzenEstimator(_ParzenEstimator):
    def _calculate_numerical_distributions(
        self,
        observations: np.ndarray,
        low: float,
        high: float,
        step: float | None,
        parameters: _CustomizableParzenEstimatorParameters,
    ) -> _BatchedDistributions:
        domain_range = high - low + (step or 0)
        consider_prior = parameters.consider_prior or len(observations) == 0

        if consider_prior:
            mus = np.append(observations, [0.5 * (low + high)])
        else:
            mus = observations.copy()

        if parameters.bandwidth_strategy == "hyperopt":
            sigmas = _bandwidth_hyperopt(mus, low, high, step)
        elif parameters.bandwidth_strategy == "optuna":
            sigmas = _bandwidth_optuna(
                n_observations=len(observations),
                consider_prior=consider_prior,
                domain_range=domain_range,
                dim=len(self._search_space),
            )
        elif parameters.bandwidth_strategy == "scott":
            sigmas = _bandwidth_scott(mus)
        else:
            raise ValueError(f"Got unknown bandwidth_strategy={parameters.bandwidth_strategy}.")

        sigmas = _clip_bandwidth(
            sigmas=sigmas,
            n_observations=len(observations),
            domain_range=domain_range,
            consider_magic_clip=parameters.consider_magic_clip,
            consider_prior=consider_prior,
            b_magic_exponent=parameters.b_magic_exponent,
            min_bandwidth_factor=parameters.min_bandwidth_factor,
        )

        if step is None:
            return _BatchedTruncNormDistributions(mus, sigmas, low, high)
        else:
            return _BatchedDiscreteTruncNormDistributions(mus, sigmas, low, high, step)

    def _calculate_categorical_distributions(
        self,
        observations: np.ndarray,
        param_name: str,
        search_space: CategoricalDistribution,
        parameters: _CustomizableParzenEstimatorParameters,
    ) -> _BatchedDistributions:
        choices = search_space.choices
        n_choices = len(choices)
        if len(observations) == 0:
            return _BatchedCategoricalDistributions(
                weights=np.full((1, n_choices), fill_value=1.0 / n_choices)
            )

        n_kernels = len(observations) + parameters.consider_prior
        observed_indices = observations.astype(int)
        if parameters.categorical_prior_weight is None:
            # Optuna's default: each kernel mixes a uniform prior with a bump on the observed choice.
            weights = np.full(shape=(n_kernels, n_choices), fill_value=1.0 / n_kernels)
            weights[np.arange(len(observed_indices)), observed_indices] += 1
            weights /= weights.sum(axis=1, keepdims=True)
        else:
            assert 0 <= parameters.categorical_prior_weight <= 1
            b = parameters.categorical_prior_weight
            # Each kernel puts 1 - b on the observed choice and spreads b over the others;
            # the last kernel (the prior, when present) stays uniform.
            weights = np.full(shape=(n_kernels, n_choices), fill_value=b / (n_choices - 1))
            weights[np.arange(len(observed_indices)), observed_indices] = 1 - b
            weights[-1] = 1.0 / n_choices

        return _BatchedCategoricalDistributions(weights)
@@ -0,0 +1 @@
optuna>=4.0.0