Skip to content

Commit

Permalink
Merge pull request optuna#5281 from nzw0301/unify-timeline-plot-test
Browse files Browse the repository at this point in the history
Unify and refactor `plot_timeline` test
  • Loading branch information
eukaryo authored Mar 13, 2024
2 parents 4d3f7cf + 39d3325 commit c2ec044
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 43 deletions.
23 changes: 0 additions & 23 deletions tests/visualization_tests/matplotlib_tests/test_timeline.py

This file was deleted.

65 changes: 45 additions & 20 deletions tests/visualization_tests/test_timeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from io import BytesIO
import time
from typing import Any
from typing import Callable

import _pytest.capture
import pytest
Expand All @@ -12,23 +13,33 @@
from optuna.samplers._base import _CONSTRAINTS_KEY
from optuna.study.study import Study
from optuna.trial import TrialState
from optuna.visualization import plot_timeline as plotly_plot_timeline
from optuna.visualization._plotly_imports import _imports as plotly_imports
from optuna.visualization._timeline import _get_timeline_info
from optuna.visualization.matplotlib import plot_timeline as plt_plot_timeline
from optuna.visualization.matplotlib._matplotlib_imports import _imports as plt_imports


if plotly_imports.is_successful():
from optuna.visualization._plotly_imports import go

from optuna.visualization._timeline import _get_timeline_info
from optuna.visualization._timeline import plot_timeline
if plt_imports.is_successful():
from optuna.visualization.matplotlib._matplotlib_imports import plt


parametrize_plot_timeline = pytest.mark.parametrize(
"plot_timeline",
[plotly_plot_timeline, plt_plot_timeline],
)


def _create_study(
trial_states_list: list[TrialState],
trial_states: list[TrialState],
trial_sys_attrs: dict[str, Any] | None = None,
) -> Study:
study = optuna.create_study()
fmax = float(len(trial_states_list))
for i, s in enumerate(trial_states_list):
fmax = float(len(trial_states))
for i, s in enumerate(trial_states):
study.add_trial(
optuna.trial.create_trial(
params={"x": float(i)},
Expand Down Expand Up @@ -115,26 +126,36 @@ def test_get_timeline_info_negative_elapsed_time(capsys: _pytest.capture.Capture
assert bar.complete < bar.start


@parametrize_plot_timeline
@pytest.mark.parametrize(
"trial_states_list",
"trial_states",
[
[],
[TrialState.COMPLETE, TrialState.PRUNED, TrialState.FAIL, TrialState.RUNNING],
[TrialState.RUNNING, TrialState.FAIL, TrialState.PRUNED, TrialState.COMPLETE],
],
)
def test_get_timeline_plot(trial_states_list: list[TrialState]) -> None:
study = _create_study(trial_states_list)
fig = plot_timeline(study)
assert type(fig) is go.Figure
fig.write_image(BytesIO())
def test_get_timeline_plot(
plot_timeline: Callable[..., Any], trial_states: list[TrialState]
) -> None:
study = _create_study(trial_states)
figure = plot_timeline(study)

if isinstance(figure, go.Figure):
figure.write_image(BytesIO())
else:
plt.savefig(BytesIO())
plt.close()


@parametrize_plot_timeline
@pytest.mark.parametrize("waiting_time", [0.0, 1.5])
def test_get_timeline_plot_with_killed_running_trials(waiting_time: float) -> None:
def test_get_timeline_plot_with_killed_running_trials(
plot_timeline: Callable[..., Any], waiting_time: float
) -> None:
def _objective_with_sleep(trial: optuna.Trial) -> float:
time.sleep(0.1)
trial.suggest_float("x", -1, 1)
trial.suggest_float("x", -1.0, 1.0)
return 1.0

study = optuna.create_study()
Expand All @@ -148,10 +169,14 @@ def _objective_with_sleep(trial: optuna.Trial) -> float:
study.optimize(_objective_with_sleep, n_trials=2)

time.sleep(waiting_time)
fig = plot_timeline(study)
bar_colors = [d["marker"]["color"] for d in fig["data"]]
assert "green" in bar_colors, "Running trial, i.e. green color, must be included."
bar_length_in_milliseconds = fig["data"][1]["x"][0]
# If the waiting time is too long, stop the timeline plots for running trials.
assert waiting_time < 1.0 or bar_length_in_milliseconds < waiting_time * 1000
fig.write_image(BytesIO())
figure = plot_timeline(study)

if isinstance(figure, go.Figure):
bar_colors = [d["marker"]["color"] for d in figure["data"]]
assert "green" in bar_colors, "Running trial, i.e. green color, must be included."
bar_length_in_milliseconds = figure["data"][1]["x"][0]
# If the waiting time is too long, stop the timeline plots for running trials.
assert waiting_time < 1.0 or bar_length_in_milliseconds < waiting_time * 1000
figure.write_image(BytesIO())
else:
pytest.skip("Matplotlib test is unimplemented.")

0 comments on commit c2ec044

Please sign in to comment.