diff --git a/tests/visualization_tests/matplotlib_tests/test_timeline.py b/tests/visualization_tests/matplotlib_tests/test_timeline.py deleted file mode 100644 index 84abae7a23..0000000000 --- a/tests/visualization_tests/matplotlib_tests/test_timeline.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -from io import BytesIO - -import pytest - -from optuna.trial import TrialState -from optuna.visualization.matplotlib._timeline import plot_timeline -from tests.visualization_tests.test_timeline import _create_study - - -@pytest.mark.parametrize( - "trial_states_list", - [ - [], - [TrialState.COMPLETE, TrialState.PRUNED, TrialState.FAIL, TrialState.RUNNING], - [TrialState.RUNNING, TrialState.FAIL, TrialState.PRUNED, TrialState.COMPLETE], - ], -) -def test_get_timeline_plot(trial_states_list: list[TrialState]) -> None: - study = _create_study(trial_states_list) - fig = plot_timeline(study) - fig.get_figure().savefig(BytesIO()) diff --git a/tests/visualization_tests/test_timeline.py b/tests/visualization_tests/test_timeline.py index 43d38bd489..bd339e9725 100644 --- a/tests/visualization_tests/test_timeline.py +++ b/tests/visualization_tests/test_timeline.py @@ -4,6 +4,7 @@ from io import BytesIO import time from typing import Any +from typing import Callable import _pytest.capture import pytest @@ -12,23 +13,33 @@ from optuna.samplers._base import _CONSTRAINTS_KEY from optuna.study.study import Study from optuna.trial import TrialState +from optuna.visualization import plot_timeline as plotly_plot_timeline from optuna.visualization._plotly_imports import _imports as plotly_imports +from optuna.visualization._timeline import _get_timeline_info +from optuna.visualization.matplotlib import plot_timeline as plt_plot_timeline +from optuna.visualization.matplotlib._matplotlib_imports import _imports as plt_imports if plotly_imports.is_successful(): from optuna.visualization._plotly_imports import go -from optuna.visualization._timeline import _get_timeline_info -from optuna.visualization._timeline import plot_timeline +if plt_imports.is_successful(): + from optuna.visualization.matplotlib._matplotlib_imports import plt + + +parametrize_plot_timeline = pytest.mark.parametrize( + "plot_timeline", + [plotly_plot_timeline, plt_plot_timeline], +) def _create_study( - trial_states_list: list[TrialState], + trial_states: list[TrialState], trial_sys_attrs: dict[str, Any] | None = None, ) -> Study: study = optuna.create_study() - fmax = float(len(trial_states_list)) - for i, s in enumerate(trial_states_list): + fmax = float(len(trial_states)) + for i, s in enumerate(trial_states): study.add_trial( optuna.trial.create_trial( params={"x": float(i)}, @@ -115,26 +126,36 @@ def test_get_timeline_info_negative_elapsed_time(capsys: _pytest.capture.Capture assert bar.complete < bar.start +@parametrize_plot_timeline @pytest.mark.parametrize( - "trial_states_list", + "trial_states", [ [], [TrialState.COMPLETE, TrialState.PRUNED, TrialState.FAIL, TrialState.RUNNING], [TrialState.RUNNING, TrialState.FAIL, TrialState.PRUNED, TrialState.COMPLETE], ], ) -def test_get_timeline_plot(trial_states_list: list[TrialState]) -> None: - study = _create_study(trial_states_list) - fig = plot_timeline(study) - assert type(fig) is go.Figure - fig.write_image(BytesIO()) +def test_get_timeline_plot( + plot_timeline: Callable[..., Any], trial_states: list[TrialState] +) -> None: + study = _create_study(trial_states) + figure = plot_timeline(study) + + if isinstance(figure, go.Figure): + figure.write_image(BytesIO()) + else: + plt.savefig(BytesIO()) + plt.close() +@parametrize_plot_timeline @pytest.mark.parametrize("waiting_time", [0.0, 1.5]) -def test_get_timeline_plot_with_killed_running_trials(waiting_time: float) -> None: +def test_get_timeline_plot_with_killed_running_trials( + plot_timeline: Callable[..., Any], waiting_time: float +) -> None: def _objective_with_sleep(trial: optuna.Trial) -> float: time.sleep(0.1) - trial.suggest_float("x", -1, 1) + trial.suggest_float("x", -1.0, 1.0) return 1.0 study = optuna.create_study() @@ -148,10 +169,14 @@ def _objective_with_sleep(trial: optuna.Trial) -> float: study.optimize(_objective_with_sleep, n_trials=2) time.sleep(waiting_time) - fig = plot_timeline(study) - bar_colors = [d["marker"]["color"] for d in fig["data"]] - assert "green" in bar_colors, "Running trial, i.e. green color, must be included." - bar_length_in_milliseconds = fig["data"][1]["x"][0] - # If the waiting time is too long, stop the timeline plots for running trials. - assert waiting_time < 1.0 or bar_length_in_milliseconds < waiting_time * 1000 - fig.write_image(BytesIO()) + figure = plot_timeline(study) + + if isinstance(figure, go.Figure): + bar_colors = [d["marker"]["color"] for d in figure["data"]] + assert "green" in bar_colors, "Running trial, i.e. green color, must be included." + bar_length_in_milliseconds = figure["data"][1]["x"][0] + # If the waiting time is too long, stop the timeline plots for running trials. + assert waiting_time < 1.0 or bar_length_in_milliseconds < waiting_time * 1000 + figure.write_image(BytesIO()) + else: + pytest.skip("Matplotlib test is unimplemented.")