Commit

Merge branch 'optuna:master' into matplotlib-tests

DanielAvdar authored Feb 27, 2024
2 parents add6966 + ec0d5b5 commit 07a0d7b
Showing 11 changed files with 176 additions and 135 deletions.
3 changes: 3 additions & 0 deletions docs/source/reference/pruners.rst
@@ -5,6 +5,9 @@ optuna.pruners

The :mod:`~optuna.pruners` module defines a :class:`~optuna.pruners.BasePruner` class characterized by an abstract :meth:`~optuna.pruners.BasePruner.prune` method, which, for a given trial and its associated study, returns a boolean value representing whether the trial should be pruned. This determination is made based on stored intermediate values of the objective function, as previously reported for the trial using :meth:`optuna.trial.Trial.report`. The remaining classes in this module represent child classes, inheriting from :class:`~optuna.pruners.BasePruner`, which implement different pruning strategies.

.. warning::
Currently, the :mod:`~optuna.pruners` module is expected to be used only for single-objective optimization.

.. seealso::
:ref:`pruning` tutorial explains the concept of the pruner classes and a minimal example.
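For orientation, here is a minimal sketch of how a pruner is driven from an objective function via :meth:`optuna.trial.Trial.report` and :meth:`optuna.trial.Trial.should_prune`; the step count and the choice of :class:`~optuna.pruners.MedianPruner` below are illustrative only.

import optuna


def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    for step in range(100):
        # Report an intermediate value for this step.
        intermediate = (x - 2) ** 2 + 1.0 / (step + 1)
        trial.report(intermediate, step)
        # The pruner decides from the reported values whether to stop the trial early.
        if trial.should_prune():
            raise optuna.TrialPruned()
    return (x - 2) ** 2


study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)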

47 changes: 36 additions & 11 deletions optuna/pruners/_wilcoxon.py
@@ -32,8 +32,8 @@ class WilcoxonPruner(BasePruner):
This includes the mean performance of n (e.g., 100)
shuffled inputs, the mean performance of k-fold cross validation, etc.
There can be "easy" or "hard" inputs (the pruner handles correspondence of
the inputs between different trials),
but it is recommended to shuffle the order of inputs once before the optimization.
the inputs between different trials).
In each trial, it is recommended to shuffle the order in which data is processed.
When you use this pruner, you must call `Trial.report(value, step)` function for each step (e.g., input id) with
the evaluated value. This is different from other pruners in that the reported value need not converge
@@ -58,22 +58,23 @@ def eval_func(param, input_):
input_data = np.linspace(-1, 1, 100)
# It is recommended to shuffle the input data once before optimization.
np.random.shuffle(input_data)
rng = np.random.default_rng()
def objective(trial):
s = 0.0
for i in range(len(input_data)):
# In each trial, it is recommended to shuffle the order in which data is processed.
ordering = rng.permutation(range(len(input_data)))
s = []
for i in ordering:
param = trial.suggest_float("param", -1, 1)
loss = eval_func(param, input_data[i])
trial.report(loss, i)
s += loss
s.append(loss)
if trial.should_prune():
raise optuna.TrialPruned()
return sum(s) / len(s) # An advanced workaround (see the note below).
# raise optuna.TrialPruned()
return s / len(input_data)
return sum(s) / len(s)
study = optuna.study.create_study(
@@ -86,6 +87,16 @@ def objective(trial):
This pruner cannot handle ``infinity`` or ``nan`` values.
Trials containing those values are never pruned.
.. note::
As an advanced workaround, if `trial.should_prune()` returns `True`,
you can return an estimate of the final value (e.g., the average of all evaluated values)
instead of raising `optuna.TrialPruned()`.
Some algorithms, including `TPESampler`, internally split trials into below (good) and above (bad),
and a pruned trial is always classified as above.
However, some trials that are only slightly worse than the best trial (e.g., in the top 10%)
will be pruned even though they should be classified as below.
This workaround provides beneficial information about such trials to these algorithms.
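A minimal sketch of this workaround, reusing the `eval_func`, `input_data`, and `rng` from the docstring example above: when pruning is requested, the objective returns the mean of the values evaluated so far instead of raising.

def objective(trial):
    param = trial.suggest_float("param", -1, 1)
    s = []
    for i in rng.permutation(range(len(input_data))):
        loss = eval_func(param, input_data[i])
        trial.report(loss, i)
        s.append(loss)
        if trial.should_prune():
            # Return an estimate of the final value instead of raising TrialPruned,
            # so samplers such as TPESampler can still rank this trial.
            return sum(s) / len(s)
    return sum(s) / len(s)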
Args:
p_threshold:
The p-value threshold for pruning. This value should be between 0 and 1.
@@ -154,6 +165,8 @@ def prune(self, study: "optuna.study.Study", trial: FrozenTrial) -> bool:
_, idx1, idx2 = np.intersect1d(steps, best_steps, return_indices=True)

if len(idx1) < len(step_values):
# This if-statement is never satisfied if the following "average_is_best" safety check works,
# because that check ensures that the best trial always has all the steps.
warnings.warn(
"WilcoxonPruner finds steps that exist in the current trial "
"but not in the best trial. "
@@ -165,8 +178,20 @@ def prune(self, study: "optuna.study.Study", trial: FrozenTrial) -> bool:
if len(diff_values) < self._n_startup_steps:
return False

alt = "less" if study.direction == StudyDirection.MAXIMIZE else "greater"
if study.direction == StudyDirection.MAXIMIZE:
alt = "less"
average_is_best = best_trial.value <= sum(step_values) / len(step_values)
else:
alt = "greater"
average_is_best = best_trial.value >= sum(step_values) / len(step_values)

# We use zsplit to avoid the problem when all values are zero.
p = ss.wilcoxon(diff_values, alternative=alt, zero_method="zsplit").pvalue

if p < self._p_threshold and average_is_best:
# ss.wilcoxon found that the current trial is probably worse than the best trial,
# but the value of the best trial was not better than
# the average of the current trial's intermediate values.
# For safety, WilcoxonPruner concludes not to prune the trial for now.
return False
return p < self._p_threshold
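As a standalone illustration of the statistical step above (not Optuna's internal code), the decision reduces to a one-sided Wilcoxon signed-rank test on paired intermediate values; the numbers and threshold below are made up.

import numpy as np
import scipy.stats as ss

# Losses of the best trial and the current trial on the same steps (made-up values).
best_values = np.array([0.10, 0.12, 0.09, 0.11, 0.13, 0.10])
current_values = np.array([0.15, 0.14, 0.16, 0.13, 0.17, 0.15])

# For minimization, "greater" asks whether the current trial's losses tend to be larger.
diff_values = current_values - best_values
# zero_method="zsplit" avoids problems when some differences are exactly zero.
p = ss.wilcoxon(diff_values, alternative="greater", zero_method="zsplit").pvalue
print(p, p < 0.1)  # Prune if p is below the chosen p_threshold (here 0.1).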
11 changes: 6 additions & 5 deletions optuna/storages/_journal/file.py
@@ -49,7 +49,7 @@ def acquire(self) -> bool:
"""Acquire a lock in a blocking way by creating a symbolic link of a file.
Returns:
:obj:`True` if it succeeded in creating a symbolic link of `self._lock_target_file`.
:obj:`True` if it succeeded in creating a symbolic link of ``self._lock_target_file``.
"""
sleep_secs = 0.001
@@ -101,7 +101,7 @@ def acquire(self) -> bool:
"""Acquire a lock in a blocking way by creating a lock file.
Returns:
:obj:`True` if it succeeded in creating a `self._lock_file`
:obj:`True` if it succeeded in creating a ``self._lock_file``.
"""
sleep_secs = 0.001
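For context, a rough sketch of the lock-file pattern such an `acquire` typically relies on (this is not the exact Optuna implementation): create the lock file atomically, and back off while another process holds it.

import errno
import os
import time


def acquire_lock(lock_path: str, sleep_secs: float = 0.001) -> bool:
    """Block until the lock file is created atomically, then return True."""
    while True:
        try:
            fd = os.open(lock_path, os.O_CREAT | os.O_EXCL)
            os.close(fd)
            return True
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
            # Another process holds the lock; wait and retry with gentle backoff.
            time.sleep(sleep_secs)
            sleep_secs = min(sleep_secs * 2, 1.0)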
@@ -158,13 +158,14 @@ def __init__(self, file_path: str, lock_obj: Optional[JournalFileBaseLock] = Non
self._file_path: str = file_path
self._lock = lock_obj or JournalFileSymlinkLock(self._file_path)
if not os.path.exists(self._file_path):
open(self._file_path, "ab").close() # Create a file if it does not exist
open(self._file_path, "ab").close() # Create a file if it does not exist.
self._log_number_offset: Dict[int, int] = {0: 0}

def read_logs(self, log_number_from: int) -> List[Dict[str, Any]]:
logs = []
with open(self._file_path, "rb") as f:
# Maintain remaining_log_size to allow writing by another process while reading the log
# Maintain remaining_log_size to allow writing by another process
# while reading the log.
remaining_log_size = os.stat(self._file_path).st_size
log_number_start = 0
if log_number_from in self._log_number_offset:
@@ -187,7 +188,7 @@ def read_logs(self, log_number_from: int) -> List[Dict[str, Any]]:
if log_number < log_number_from:
continue

# Ensure that each line ends with line separators (\n, \r\n)
# Ensure that each line ends with line separators (\n, \r\n).
if not line.endswith(b"\n"):
last_decode_error = ValueError("Invalid log format.")
del self._log_number_offset[log_number + 1]
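To make the two comments above concrete, here is a rough sketch (not Optuna's actual reader) of consuming a JSON-lines log while another process may still be appending to it: only lines that end with a newline are treated as complete records.

import json
import os
from typing import Any, Dict, List


def read_complete_logs(file_path: str) -> List[Dict[str, Any]]:
    logs: List[Dict[str, Any]] = []
    # Snapshot the size first so bytes appended afterwards are ignored in this pass.
    remaining_log_size = os.stat(file_path).st_size
    with open(file_path, "rb") as f:
        while remaining_log_size > 0:
            line = f.readline()
            remaining_log_size -= len(line)
            # A line without a trailing newline may still be mid-write; stop here.
            if not line.endswith(b"\n"):
                break
            logs.append(json.loads(line))
    return logs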
2 changes: 1 addition & 1 deletion optuna/study/_tell.py
@@ -164,7 +164,7 @@ def _tell_with_warning(

# Cast values to list of floats.
if values is not None:
# values have beed checked to be castable to floats in _check_values_are_feasible.
# values have been checked to be castable to floats in _check_values_are_feasible.
values = [float(value) for value in values]

# Post-processing and storing the trial.
45 changes: 23 additions & 22 deletions optuna/visualization/_timeline.py
@@ -115,34 +115,35 @@ def _is_running_trials_in_study(study: Study, max_run_duration: datetime.timedel

def _get_timeline_info(study: Study) -> _TimelineInfo:
bars = []

max_datetime = _get_max_datetime_complete(study)
timedelta_for_small_bar = datetime.timedelta(seconds=1)
for t in study.get_trials(deepcopy=False):
date_start = t.datetime_start or max_datetime
date_complete = (
for trial in study.get_trials(deepcopy=False):
datetime_start = trial.datetime_start or max_datetime
datetime_complete = (
max_datetime + timedelta_for_small_bar
if t.state == TrialState.RUNNING
else t.datetime_complete or date_start + timedelta_for_small_bar
if trial.state == TrialState.RUNNING
else trial.datetime_complete or datetime_start + timedelta_for_small_bar
)
infeasible = (
False
if _CONSTRAINTS_KEY not in t.system_attrs
else any([x > 0 for x in t.system_attrs[_CONSTRAINTS_KEY]])
if _CONSTRAINTS_KEY not in trial.system_attrs
else any([x > 0 for x in trial.system_attrs[_CONSTRAINTS_KEY]])
)
if date_complete < date_start:
if datetime_complete < datetime_start:
_logger.warning(
(
f"The start and end times for Trial {t.number} seem to be reversed. "
f"The start time is {date_start} and the end time is {date_complete}."
f"The start and end times for Trial {trial.number} seem to be reversed. "
f"The start time is {datetime_start} and the end time is {datetime_complete}."
)
)
bars.append(
_TimelineBarInfo(
number=t.number,
start=date_start,
complete=date_complete,
state=t.state,
hovertext=_make_hovertext(t),
number=trial.number,
start=datetime_start,
complete=datetime_complete,
state=trial.state,
hovertext=_make_hovertext(trial),
infeasible=infeasible,
)
)
@@ -163,15 +164,15 @@ def _get_timeline_plot(info: _TimelineInfo) -> "go.Figure":
}

fig = go.Figure()
for s in sorted(TrialState, key=lambda x: x.name):
if s.name == "COMPLETE":
infeasible_bars = [b for b in info.bars if b.state == s and b.infeasible]
feasible_bars = [b for b in info.bars if b.state == s and not b.infeasible]
for state in sorted(TrialState, key=lambda x: x.name):
if state.name == "COMPLETE":
infeasible_bars = [b for b in info.bars if b.state == state and b.infeasible]
feasible_bars = [b for b in info.bars if b.state == state and not b.infeasible]
_plot_bars(infeasible_bars, "#cccccc", "INFEASIBLE", fig)
_plot_bars(feasible_bars, _cm[s.name], s.name, fig)
_plot_bars(feasible_bars, _cm[state.name], state.name, fig)
else:
bars = [b for b in info.bars if b.state == s]
_plot_bars(bars, _cm[s.name], s.name, fig)
bars = [b for b in info.bars if b.state == state]
_plot_bars(bars, _cm[state.name], state.name, fig)
fig.update_xaxes(type="date")
fig.update_layout(
go.Layout(
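The helpers above back the public timeline plot. A typical call (with a throwaway study purely for illustration) looks like the sketch below; `optuna.visualization.matplotlib.plot_timeline`, edited in the next file, is the Matplotlib counterpart.

import optuna


def objective(trial):
    x = trial.suggest_float("x", 0, 1)
    return x ** 2


study = optuna.create_study()
study.optimize(objective, n_trials=10)

# Plotly backend implemented in optuna/visualization/_timeline.py.
fig = optuna.visualization.plot_timeline(study)
fig.show()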
8 changes: 4 additions & 4 deletions optuna/visualization/matplotlib/_timeline.py
@@ -107,11 +107,11 @@ def _get_timeline_plot(info: _TimelineInfo) -> "Axes":
fig.tight_layout()

assert len(info.bars) > 0
start_time = min([b.start for b in info.bars])
complete_time = max([b.complete for b in info.bars])
margin = (complete_time - start_time) * 0.05
first_start_time = min([b.start for b in info.bars])
last_complete_time = max([b.complete for b in info.bars])
margin = (last_complete_time - first_start_time) * 0.05

ax.set_xlim(right=complete_time + margin, left=start_time - margin)
ax.set_xlim(right=last_complete_time + margin, left=first_start_time - margin)
ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))
ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%H:%M:%S"))
plt.gcf().autofmt_xdate()