Merge pull request #1 from dheemantha-bhat/tensorboard
Update tensorboard integration
Showing 5 changed files with 248 additions and 1 deletion.
optuna_integration/tensorboard.py (new file):
@@ -0,0 +1,121 @@
import os
from typing import Dict

import optuna
from optuna._experimental import experimental_class
from optuna_integration._imports import try_import
from optuna.logging import get_logger


with try_import() as _imports:
    from tensorboard.plugins.hparams import api as hp
    import tensorflow as tf

_logger = get_logger(__name__)


@experimental_class("2.0.0")
class TensorBoardCallback:
    """Callback to track Optuna trials with TensorBoard.

    This callback adds relevant information that is tracked by Optuna to TensorBoard.

    See `the example <https://github.com/optuna/optuna-examples/blob/main/
    tensorboard/tensorboard_simple.py>`_.

    Args:
        dirname:
            Directory to store TensorBoard logs.
        metric_name:
            Name of the metric. Since the metric itself is just a number,
            `metric_name` can be used to give it a name. So you know later
            if it was roc-auc or accuracy.
    """

    def __init__(self, dirname: str, metric_name: str) -> None:
        _imports.check()
        self._dirname = dirname
        self._metric_name = metric_name
        self._hp_params: Dict[str, hp.HParam] = {}

    def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
        if len(self._hp_params) == 0:
            self._initialization(study)
        if trial.state != optuna.trial.TrialState.COMPLETE:
            return
        trial_value = trial.value if trial.value is not None else float("nan")
        hparams = {}
        for param_name, param_value in trial.params.items():
            if param_name not in self._hp_params:
                self._add_distributions(trial.distributions)
            param = self._hp_params[param_name]
            if isinstance(param.domain, hp.Discrete):
                hparams[param] = param.domain.dtype(param_value)
            else:
                hparams[param] = param_value
        run_name = "trial-%d" % trial.number
        run_dir = os.path.join(self._dirname, run_name)
        with tf.summary.create_file_writer(run_dir).as_default():
            hp.hparams(hparams, trial_id=run_name)  # record the values used in this trial
            tf.summary.scalar(self._metric_name, trial_value, step=trial.number)

    def _add_distributions(
        self, distributions: Dict[str, optuna.distributions.BaseDistribution]
    ) -> None:
        supported_distributions = (
            optuna.distributions.CategoricalDistribution,
            optuna.distributions.FloatDistribution,
            optuna.distributions.IntDistribution,
        )

        for param_name, param_distribution in distributions.items():
            if isinstance(param_distribution, optuna.distributions.FloatDistribution):
                self._hp_params[param_name] = hp.HParam(
                    param_name,
                    hp.RealInterval(float(param_distribution.low), float(param_distribution.high)),
                )
            elif isinstance(param_distribution, optuna.distributions.IntDistribution):
                self._hp_params[param_name] = hp.HParam(
                    param_name,
                    hp.IntInterval(param_distribution.low, param_distribution.high),
                )
            elif isinstance(param_distribution, optuna.distributions.CategoricalDistribution):
                choices = param_distribution.choices
                dtype = type(choices[0])
                if any(not isinstance(choice, dtype) for choice in choices):
                    _logger.warning(
                        "Choices contains mixed types, which is not supported by TensorBoard. "
                        "Converting all choices to strings."
                    )
                    choices = tuple(map(str, choices))
                elif dtype not in (int, float, bool, str):
                    _logger.warning(
                        f"Choices are of type {dtype}, which is not supported by TensorBoard. "
                        "Converting all choices to strings."
                    )
                    choices = tuple(map(str, choices))

                self._hp_params[param_name] = hp.HParam(
                    param_name,
                    hp.Discrete(choices),
                )
            else:
                distribution_list = [
                    distribution.__name__ for distribution in supported_distributions
                ]
                raise NotImplementedError(
                    "The distribution {} is not implemented. "
                    "The parameter distribution should be one of the {}".format(
                        param_distribution, distribution_list
                    )
                )

    def _initialization(self, study: optuna.Study) -> None:
        completed_trials = [
            trial
            for trial in study.get_trials(deepcopy=False)
            if trial.state == optuna.trial.TrialState.COMPLETE
        ]
        for trial in completed_trials:
            self._add_distributions(trial.distributions)
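For context, the callback above is used by passing it to study.optimize through the callbacks argument, which is exactly how the tests further down exercise it. A minimal sketch along those lines (the toy objective and the ./tb_logs directory are illustrative only, and TensorFlow plus TensorBoard must be installed for _imports.check() to pass):

import optuna

from optuna_integration.tensorboard import TensorBoardCallback


def objective(trial: optuna.trial.Trial) -> float:
    # Toy objective; any Optuna objective works here.
    x = trial.suggest_float("x", -10.0, 10.0)
    return (x - 2) ** 2


# One TensorBoard run per completed trial is written under ./tb_logs
# and can be viewed with `tensorboard --logdir ./tb_logs`.
tbcallback = TensorBoardCallback(dirname="./tb_logs", metric_name="objective_value")
study = optuna.create_study()
study.optimize(objective, n_trials=10, callbacks=[tbcallback])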
The "all" optional-dependency list gains a "tensorboard" entry:
@@ -57,6 +57,7 @@ all = [
    "mxnet",
    "shap",
    "skorch",
    "tensorboard",
    "tensorflow",
]
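With "tensorboard" listed next to "tensorflow" in the all extra, the callback's optional dependencies can be pulled in via pip extras. A possible invocation, assuming the distribution is published as optuna-integration:

pip install "optuna-integration[all]"

For a leaner environment, installing tensorboard and tensorflow directly alongside optuna-integration should work as well, since the module only imports them through try_import.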
A new test module for the TensorBoard callback:
@@ -0,0 +1,114 @@
import os
import shutil
import tempfile

import pytest

import optuna

from optuna_integration._imports import try_import
from optuna_integration.tensorboard import TensorBoardCallback


with try_import():
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

pytestmark = pytest.mark.integration


def _objective_func(trial: optuna.trial.Trial) -> float:
    u = trial.suggest_int("u", 0, 10, step=2)
    v = trial.suggest_int("v", 1, 10, log=True)
    w = trial.suggest_float("w", -1.0, 1.0, step=0.1)
    x = trial.suggest_float("x", -1.0, 1.0)
    y = trial.suggest_float("y", 20.0, 30.0, log=True)
    z = trial.suggest_categorical("z", (-1.0, 1.0))
    trial.set_user_attr("my_user_attr", "my_user_attr_value")
    return u + v + w + (x - 2) ** 2 + (y - 25) ** 2 + z


def test_study_name() -> None:
    dirname = tempfile.mkdtemp()
    metric_name = "target"
    study_name = "test_tensorboard_integration"

    tbcallback = TensorBoardCallback(dirname, metric_name)
    study = optuna.create_study(study_name=study_name)
    study.optimize(_objective_func, n_trials=1, callbacks=[tbcallback])

    event_acc = EventAccumulator(os.path.join(dirname, "trial-0"))
    event_acc.Reload()

    try:
        assert len(event_acc.Tensors("target")) == 1
    finally:
        shutil.rmtree(dirname)


def test_cast_float() -> None:
    def objective(trial: optuna.trial.Trial) -> float:
        x = trial.suggest_float("x", 1, 2)
        y = trial.suggest_float("y", 1, 2, log=True)
        assert isinstance(x, float)
        assert isinstance(y, float)
        return x + y

    dirname = tempfile.mkdtemp()
    metric_name = "target"
    study_name = "test_tensorboard_integration"

    tbcallback = TensorBoardCallback(dirname, metric_name)
    study = optuna.create_study(study_name=study_name)
    study.optimize(objective, n_trials=1, callbacks=[tbcallback])


def test_categorical() -> None:
    def objective(trial: optuna.trial.Trial) -> float:
        x = trial.suggest_categorical("x", [1, 2, 3])
        assert isinstance(x, int)
        return x

    dirname = tempfile.mkdtemp()
    metric_name = "target"
    study_name = "test_tensorboard_integration"

    tbcallback = TensorBoardCallback(dirname, metric_name)
    study = optuna.create_study(study_name=study_name)
    study.optimize(objective, n_trials=1, callbacks=[tbcallback])


def test_categorical_mixed_types() -> None:
    def objective(trial: optuna.trial.Trial) -> float:
        x = trial.suggest_categorical("x", [None, 1, 2, 3.14, True, "foo"])
        assert x is None or isinstance(x, (int, float, bool, str))
        return len(str(x))

    dirname = tempfile.mkdtemp()
    metric_name = "target"
    study_name = "test_tensorboard_integration"

    tbcallback = TensorBoardCallback(dirname, metric_name)
    study = optuna.create_study(study_name=study_name)
    study.optimize(objective, n_trials=10, callbacks=[tbcallback])


def test_categorical_unsupported_types() -> None:
    def objective(trial: optuna.trial.Trial) -> float:
        x = trial.suggest_categorical("x", [[1, 2], [3, 4, 5], [6]])  # type: ignore[list-item]
        assert isinstance(x, list)
        return len(x)

    dirname = tempfile.mkdtemp()
    metric_name = "target"
    study_name = "test_tensorboard_integration"

    tbcallback = TensorBoardCallback(dirname, metric_name)
    study = optuna.create_study(study_name=study_name)
    study.optimize(objective, n_trials=10, callbacks=[tbcallback])


def test_experimental_warning() -> None:
    with pytest.warns(optuna.exceptions.ExperimentalWarning):
        TensorBoardCallback(dirname="", metric_name="")
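Because the module sets pytestmark = pytest.mark.integration, these tests are only selected when the integration marker is requested. A typical invocation, assuming TensorFlow and TensorBoard are installed in the environment and the marker is registered in the project's pytest configuration:

pytest -m integration

The last test reflects the experimental_class("2.0.0") decorator on TensorBoardCallback: constructing the callback is expected to emit an ExperimentalWarning.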