Skip to content

Commit

Permalink
Merge pull request optuna#4851 from adjeiv/intermediate_values_plot_c…
Browse files Browse the repository at this point in the history
…onstraints

Support constraints for intermediate values plot
  • Loading branch information
not522 authored Aug 4, 2023
2 parents 6de545a + 0da819e commit 056dd2f
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 6 deletions.
18 changes: 16 additions & 2 deletions optuna/visualization/_intermediate_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from typing import NamedTuple

from optuna.logging import get_logger
from optuna.samplers._base import _CONSTRAINTS_KEY
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from optuna.visualization._plotly_imports import _imports

Expand All @@ -17,6 +19,7 @@
class _TrialInfo(NamedTuple):
trial_number: int
sorted_intermediate_values: list[tuple[int, float]]
feasible: bool


class _IntermediatePlotInfo(NamedTuple):
Expand All @@ -27,8 +30,15 @@ def _get_intermediate_plot_info(study: Study) -> _IntermediatePlotInfo:
trials = study.get_trials(
deepcopy=False, states=(TrialState.PRUNED, TrialState.COMPLETE, TrialState.RUNNING)
)

def _satisfies_constraints(trial: FrozenTrial) -> bool:
constraints = trial.system_attrs.get(_CONSTRAINTS_KEY)
return constraints is None or all([x <= 0.0 for x in constraints])

trial_infos = [
_TrialInfo(trial.number, sorted(trial.intermediate_values.items()))
_TrialInfo(
trial.number, sorted(trial.intermediate_values.items()), _satisfies_constraints(trial)
)
for trial in trials
if len(trial.intermediate_values) > 0
]
Expand Down Expand Up @@ -113,12 +123,16 @@ def _get_intermediate_plot(info: _IntermediatePlotInfo) -> "go.Figure":
if len(trial_infos) == 0:
return go.Figure(data=[], layout=layout)

default_marker = {"maxdisplayed": 10}

traces = [
go.Scatter(
x=tuple((x for x, _ in tinfo.sorted_intermediate_values)),
y=tuple((y for _, y in tinfo.sorted_intermediate_values)),
mode="lines+markers",
marker={"maxdisplayed": 10},
marker=default_marker
if tinfo.feasible
else {**default_marker, "color": "#CCCCCC"}, # type: ignore[dict-item]
name="Trial{}".format(tinfo.trial_number),
)
for tinfo in trial_infos
Expand Down
2 changes: 1 addition & 1 deletion optuna/visualization/matplotlib/_intermediate_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _get_intermediate_plot(info: _IntermediatePlotInfo) -> "Axes":
ax.plot(
tuple((x for x, _ in tinfo.sorted_intermediate_values)),
tuple((y for _, y in tinfo.sorted_intermediate_values)),
color=cmap(i),
color=cmap(i) if tinfo.feasible else "#CCCCCC",
marker=".",
alpha=0.7,
label="Trial{}".format(tinfo.trial_number),
Expand Down
52 changes: 49 additions & 3 deletions tests/visualization_tests/test_intermediate_plot.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from io import BytesIO
from typing import Any
from typing import Callable
from typing import Sequence

import pytest

from optuna.study import create_study
from optuna.testing.objectives import fail_objective
from optuna.trial import FrozenTrial
from optuna.trial import Trial
import optuna.visualization._intermediate_values
from optuna.visualization._intermediate_values import _get_intermediate_plot_info
Expand Down Expand Up @@ -33,7 +35,11 @@ def objective(trial: Trial, report_intermediate_values: bool) -> float:
study.optimize(lambda t: objective(t, True), n_trials=1)

assert _get_intermediate_plot_info(study) == _IntermediatePlotInfo(
trial_infos=[_TrialInfo(trial_number=0, sorted_intermediate_values=[(0, 1.0), (1, 2.0)])]
trial_infos=[
_TrialInfo(
trial_number=0, sorted_intermediate_values=[(0, 1.0), (1, 2.0)], feasible=True
)
]
)

# Test a study with one trial with intermediate values and
Expand All @@ -42,7 +48,11 @@ def objective(trial: Trial, report_intermediate_values: bool) -> float:
study.optimize(lambda t: objective(t, False), n_trials=1)

assert _get_intermediate_plot_info(study) == _IntermediatePlotInfo(
trial_infos=[_TrialInfo(trial_number=0, sorted_intermediate_values=[(0, 1.0), (1, 2.0)])]
trial_infos=[
_TrialInfo(
trial_number=0, sorted_intermediate_values=[(0, 1.0), (1, 2.0)], feasible=True
)
]
)

# Test a study of only one trial that has no intermediate values.
Expand All @@ -55,6 +65,30 @@ def objective(trial: Trial, report_intermediate_values: bool) -> float:
study.optimize(fail_objective, n_trials=1, catch=(ValueError,))
assert _get_intermediate_plot_info(study) == _IntermediatePlotInfo(trial_infos=[])

# Test a study with constraints
def objective_with_constraints(trial: Trial) -> float:
trial.set_user_attr("constraint", [trial.number % 2])

trial.report(1.0, step=0)
trial.report(2.0, step=1)
return 0.0

def constraints(trial: FrozenTrial) -> Sequence[float]:
return trial.user_attrs["constraint"]

study = create_study(sampler=optuna.samplers.NSGAIIISampler(constraints_func=constraints))
study.optimize(objective_with_constraints, n_trials=2)
assert _get_intermediate_plot_info(study) == _IntermediatePlotInfo(
trial_infos=[
_TrialInfo(
trial_number=0, sorted_intermediate_values=[(0, 1.0), (1, 2.0)], feasible=True
),
_TrialInfo(
trial_number=1, sorted_intermediate_values=[(0, 1.0), (1, 2.0)], feasible=False
),
]
)


@pytest.mark.parametrize(
"plotter",
Expand All @@ -69,7 +103,19 @@ def objective(trial: Trial, report_intermediate_values: bool) -> float:
_IntermediatePlotInfo(trial_infos=[]),
_IntermediatePlotInfo(
trial_infos=[
_TrialInfo(trial_number=0, sorted_intermediate_values=[(0, 1.0), (1, 2.0)])
_TrialInfo(
trial_number=0, sorted_intermediate_values=[(0, 1.0), (1, 2.0)], feasible=True
)
]
),
_IntermediatePlotInfo(
trial_infos=[
_TrialInfo(
trial_number=0, sorted_intermediate_values=[(0, 1.0), (1, 2.0)], feasible=True
),
_TrialInfo(
trial_number=1, sorted_intermediate_values=[(1, 2.0), (0, 1.0)], feasible=False
),
]
),
],
Expand Down

0 comments on commit 056dd2f

Please sign in to comment.