Skip to content

Commit

Permalink
Merge pull request #1 from dheemantha-bhat/tensorboard
Browse files Browse the repository at this point in the history
Update tensorboard integration
  • Loading branch information
dheemantha-bhat authored Dec 22, 2023
2 parents 58638a1 + b4bf450 commit 5e397fe
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Optuna-Integration API reference is [here](https://optuna-integration.readthedoc
* [MXNet](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#mxnet) ([example](https://github.com/optuna/optuna-examples/tree/main/mxnet))
* [SHAP](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#shap)
* [skorch](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#skorch) ([example](https://github.com/optuna/optuna-examples/tree/main/pytorch/skorch_simple.py))
* [TensorBoard](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorboard) ([example](https://github.com/optuna/optuna-examples/tree/main/tensorboard/tensorboard_simple.py))
* [tf.keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorflow) ([example](https://github.com/optuna/optuna-examples/tree/main/tfkeras/tfkeras_integration.py))

## Installation
Expand Down
12 changes: 11 additions & 1 deletion docs/source/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,21 @@ skorch

optuna.integration.SkorchPruningCallback

TensorBoard
----------

.. autosummary::
:toctree: generated/
:nosignatures:

optuna.integration.TensorBoardCallback

TensorFlow
----------

.. autosummary::
:toctree: generated/
:nosignatures:

optuna.integration.TFKerasPruningCallback
optuna.integration.TFKerasPruningCallback

121 changes: 121 additions & 0 deletions optuna_integration/tensorboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import os
from typing import Dict

Check warning on line 2 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L1-L2

Added lines #L1 - L2 were not covered by tests

import optuna
from optuna._experimental import experimental_class
from optuna_integration._imports import try_import

Check warning on line 6 in optuna_integration/tensorboard.py

View workflow job for this annotation

GitHub Actions / reviewdog

[formatters] reported by reviewdog 🐶 Raw Output: optuna_integration/tensorboard.py:6:-from optuna_integration._imports import try_import
from optuna.logging import get_logger

Check warning on line 7 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L4-L7

Added lines #L4 - L7 were not covered by tests


Check warning on line 9 in optuna_integration/tensorboard.py

View workflow job for this annotation

GitHub Actions / reviewdog

[formatters] reported by reviewdog 🐶 Raw Output: optuna_integration/tensorboard.py:8:+from optuna_integration._imports import try_import optuna_integration/tensorboard.py:9:+
with try_import() as _imports:
from tensorboard.plugins.hparams import api as hp
import tensorflow as tf

Check warning on line 12 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L10-L12

Added lines #L10 - L12 were not covered by tests

_logger = get_logger(__name__)

Check warning on line 14 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L14

Added line #L14 was not covered by tests


@experimental_class("2.0.0")
class TensorBoardCallback:

Check warning on line 18 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L17-L18

Added lines #L17 - L18 were not covered by tests
"""Callback to track Optuna trials with TensorBoard.
This callback adds relevant information that is tracked by Optuna to TensorBoard.
See `the example <https://github.com/optuna/optuna-examples/blob/main/
tensorboard/tensorboard_simple.py>`_.
Args:
dirname:
Directory to store TensorBoard logs.
metric_name:
Name of the metric. Since the metric itself is just a number,
`metric_name` can be used to give it a name. So you know later
if it was roc-auc or accuracy.
"""

def __init__(self, dirname: str, metric_name: str) -> None:
_imports.check()
self._dirname = dirname
self._metric_name = metric_name
self._hp_params: Dict[str, hp.HParam] = {}

Check warning on line 40 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L36-L40

Added lines #L36 - L40 were not covered by tests

def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
if len(self._hp_params) == 0:
self._initialization(study)
if trial.state != optuna.trial.TrialState.COMPLETE:
return
trial_value = trial.value if trial.value is not None else float("nan")
hparams = {}
for param_name, param_value in trial.params.items():
if param_name not in self._hp_params:
self._add_distributions(trial.distributions)
param = self._hp_params[param_name]
if isinstance(param.domain, hp.Discrete):
hparams[param] = param.domain.dtype(param_value)

Check warning on line 54 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L42-L54

Added lines #L42 - L54 were not covered by tests
else:
hparams[param] = param_value
run_name = "trial-%d" % trial.number
run_dir = os.path.join(self._dirname, run_name)
with tf.summary.create_file_writer(run_dir).as_default():
hp.hparams(hparams, trial_id=run_name) # record the values used in this trial
tf.summary.scalar(self._metric_name, trial_value, step=trial.number)

Check warning on line 61 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L56-L61

Added lines #L56 - L61 were not covered by tests

def _add_distributions(

Check warning on line 63 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L63

Added line #L63 was not covered by tests
self, distributions: Dict[str, optuna.distributions.BaseDistribution]
) -> None:
supported_distributions = (

Check warning on line 66 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L66

Added line #L66 was not covered by tests
optuna.distributions.CategoricalDistribution,
optuna.distributions.FloatDistribution,
optuna.distributions.IntDistribution,
)

for param_name, param_distribution in distributions.items():
if isinstance(param_distribution, optuna.distributions.FloatDistribution):
self._hp_params[param_name] = hp.HParam(

Check warning on line 74 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L72-L74

Added lines #L72 - L74 were not covered by tests
param_name,
hp.RealInterval(float(param_distribution.low), float(param_distribution.high)),
)
elif isinstance(param_distribution, optuna.distributions.IntDistribution):
self._hp_params[param_name] = hp.HParam(

Check warning on line 79 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L78-L79

Added lines #L78 - L79 were not covered by tests
param_name,
hp.IntInterval(param_distribution.low, param_distribution.high),
)
elif isinstance(param_distribution, optuna.distributions.CategoricalDistribution):
choices = param_distribution.choices
dtype = type(choices[0])
if any(not isinstance(choice, dtype) for choice in choices):
_logger.warning(

Check warning on line 87 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L83-L87

Added lines #L83 - L87 were not covered by tests
"Choices contains mixed types, which is not supported by TensorBoard. "
"Converting all choices to strings."
)
choices = tuple(map(str, choices))
elif dtype not in (int, float, bool, str):
_logger.warning(

Check warning on line 93 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L91-L93

Added lines #L91 - L93 were not covered by tests
f"Choices are of type {dtype}, which is not supported by TensorBoard. "
"Converting all choices to strings."
)
choices = tuple(map(str, choices))

Check warning on line 97 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L97

Added line #L97 was not covered by tests

self._hp_params[param_name] = hp.HParam(

Check warning on line 99 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L99

Added line #L99 was not covered by tests
param_name,
hp.Discrete(choices),
)
else:
distribution_list = [

Check warning on line 104 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L104

Added line #L104 was not covered by tests
distribution.__name__ for distribution in supported_distributions
]
raise NotImplementedError(

Check warning on line 107 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L107

Added line #L107 was not covered by tests
"The distribution {} is not implemented. "
"The parameter distribution should be one of the {}".format(
param_distribution, distribution_list
)
)

def _initialization(self, study: optuna.Study) -> None:
completed_trials = [

Check warning on line 115 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L114-L115

Added lines #L114 - L115 were not covered by tests
trial
for trial in study.get_trials(deepcopy=False)
if trial.state == optuna.trial.TrialState.COMPLETE
]
for trial in completed_trials:
self._add_distributions(trial.distributions)

Check warning on line 121 in optuna_integration/tensorboard.py

View check run for this annotation

Codecov / codecov/patch

optuna_integration/tensorboard.py#L120-L121

Added lines #L120 - L121 were not covered by tests
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ all = [
"mxnet",
"shap",
"skorch",
"tensorboard",
"tensorflow",
]

Expand Down
114 changes: 114 additions & 0 deletions tests/test_tensorboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import os
import shutil
import tempfile

import pytest

Check warning on line 5 in tests/test_tensorboard.py

View workflow job for this annotation

GitHub Actions / reviewdog

[formatters] reported by reviewdog 🐶 Raw Output: tests/test_tensorboard.py:5:-import pytest tests/test_tensorboard.py:6:-

import optuna

Check warning on line 8 in tests/test_tensorboard.py

View workflow job for this annotation

GitHub Actions / reviewdog

[formatters] reported by reviewdog 🐶 Raw Output: tests/test_tensorboard.py:6:+import pytest
from optuna_integration._imports import try_import
from optuna_integration.tensorboard import TensorBoardCallback


with try_import():
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

pytestmark = pytest.mark.integration


def _objective_func(trial: optuna.trial.Trial) -> float:
u = trial.suggest_int("u", 0, 10, step=2)
v = trial.suggest_int("v", 1, 10, log=True)
w = trial.suggest_float("w", -1.0, 1.0, step=0.1)
x = trial.suggest_float("x", -1.0, 1.0)
y = trial.suggest_float("y", 20.0, 30.0, log=True)
z = trial.suggest_categorical("z", (-1.0, 1.0))
trial.set_user_attr("my_user_attr", "my_user_attr_value")
return u + v + w + (x - 2) ** 2 + (y - 25) ** 2 + z


def test_study_name() -> None:
dirname = tempfile.mkdtemp()
metric_name = "target"
study_name = "test_tensorboard_integration"

tbcallback = TensorBoardCallback(dirname, metric_name)
study = optuna.create_study(study_name=study_name)
study.optimize(_objective_func, n_trials=1, callbacks=[tbcallback])

event_acc = EventAccumulator(os.path.join(dirname, "trial-0"))
event_acc.Reload()

try:
assert len(event_acc.Tensors("target")) == 1
except Exception as e:
raise e
finally:
shutil.rmtree(dirname)


def test_cast_float() -> None:
def objective(trial: optuna.trial.Trial) -> float:
x = trial.suggest_float("x", 1, 2)
y = trial.suggest_float("y", 1, 2, log=True)
assert isinstance(x, float)
assert isinstance(y, float)
return x + y

dirname = tempfile.mkdtemp()
metric_name = "target"
study_name = "test_tensorboard_integration"

tbcallback = TensorBoardCallback(dirname, metric_name)
study = optuna.create_study(study_name=study_name)
study.optimize(objective, n_trials=1, callbacks=[tbcallback])


def test_categorical() -> None:
def objective(trial: optuna.trial.Trial) -> float:
x = trial.suggest_categorical("x", [1, 2, 3])
assert isinstance(x, int)
return x

dirname = tempfile.mkdtemp()
metric_name = "target"
study_name = "test_tensorboard_integration"

tbcallback = TensorBoardCallback(dirname, metric_name)
study = optuna.create_study(study_name=study_name)
study.optimize(objective, n_trials=1, callbacks=[tbcallback])


def test_categorical_mixed_types() -> None:
def objective(trial: optuna.trial.Trial) -> float:
x = trial.suggest_categorical("x", [None, 1, 2, 3.14, True, "foo"])
assert x is None or isinstance(x, (int, float, bool, str))
return len(str(x))

dirname = tempfile.mkdtemp()
metric_name = "target"
study_name = "test_tensorboard_integration"

tbcallback = TensorBoardCallback(dirname, metric_name)
study = optuna.create_study(study_name=study_name)
study.optimize(objective, n_trials=10, callbacks=[tbcallback])


def test_categorical_unsupported_types() -> None:
def objective(trial: optuna.trial.Trial) -> float:
x = trial.suggest_categorical("x", [[1, 2], [3, 4, 5], [6]]) # type: ignore[list-item]
assert isinstance(x, list)
return len(x)

dirname = tempfile.mkdtemp()
metric_name = "target"
study_name = "test_tensorboard_integration"

tbcallback = TensorBoardCallback(dirname, metric_name)
study = optuna.create_study(study_name=study_name)
study.optimize(objective, n_trials=10, callbacks=[tbcallback])


def test_experimental_warning() -> None:
with pytest.warns(optuna.exceptions.ExperimentalWarning):
TensorBoardCallback(dirname="", metric_name="")

0 comments on commit 5e397fe

Please sign in to comment.