From b47535a349041166e2097f4aecd7b6a876feff24 Mon Sep 17 00:00:00 2001 From: Kento Nozawa Date: Thu, 22 Feb 2024 17:22:00 +0900 Subject: [PATCH 1/6] Rename variables --- optuna/visualization/_timeline.py | 45 ++++++++++---------- optuna/visualization/matplotlib/_timeline.py | 8 ++-- 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/optuna/visualization/_timeline.py b/optuna/visualization/_timeline.py index 28b1c430ef..071e6ab4c8 100644 --- a/optuna/visualization/_timeline.py +++ b/optuna/visualization/_timeline.py @@ -115,34 +115,35 @@ def _is_running_trials_in_study(study: Study, max_run_duration: datetime.timedel def _get_timeline_info(study: Study) -> _TimelineInfo: bars = [] + max_datetime = _get_max_datetime_complete(study) timedelta_for_small_bar = datetime.timedelta(seconds=1) - for t in study.get_trials(deepcopy=False): - date_start = t.datetime_start or max_datetime - date_complete = ( + for trial in study.get_trials(deepcopy=False): + datetime_start = trial.datetime_start or max_datetime + datetime_complete = ( max_datetime + timedelta_for_small_bar - if t.state == TrialState.RUNNING - else t.datetime_complete or date_start + timedelta_for_small_bar + if trial.state == TrialState.RUNNING + else trial.datetime_complete or datetime_start + timedelta_for_small_bar ) infeasible = ( False - if _CONSTRAINTS_KEY not in t.system_attrs - else any([x > 0 for x in t.system_attrs[_CONSTRAINTS_KEY]]) + if _CONSTRAINTS_KEY not in trial.system_attrs + else any([x > 0 for x in trial.system_attrs[_CONSTRAINTS_KEY]]) ) - if date_complete < date_start: + if datetime_complete < datetime_start: _logger.warning( ( - f"The start and end times for Trial {t.number} seem to be reversed. " - f"The start time is {date_start} and the end time is {date_complete}." + f"The start and end times for Trial {trial.number} seem to be reversed. " + f"The start time is {datetime_start} and the end time is {datetime_complete}." ) ) bars.append( _TimelineBarInfo( - number=t.number, - start=date_start, - complete=date_complete, - state=t.state, - hovertext=_make_hovertext(t), + number=trial.number, + start=datetime_start, + complete=datetime_complete, + state=trial.state, + hovertext=_make_hovertext(trial), infeasible=infeasible, ) ) @@ -163,15 +164,15 @@ def _get_timeline_plot(info: _TimelineInfo) -> "go.Figure": } fig = go.Figure() - for s in sorted(TrialState, key=lambda x: x.name): - if s.name == "COMPLETE": - infeasible_bars = [b for b in info.bars if b.state == s and b.infeasible] - feasible_bars = [b for b in info.bars if b.state == s and not b.infeasible] + for state in sorted(TrialState, key=lambda x: x.name): + if state.name == "COMPLETE": + infeasible_bars = [b for b in info.bars if b.state == state and b.infeasible] + feasible_bars = [b for b in info.bars if b.state == state and not b.infeasible] _plot_bars(infeasible_bars, "#cccccc", "INFEASIBLE", fig) - _plot_bars(feasible_bars, _cm[s.name], s.name, fig) + _plot_bars(feasible_bars, _cm[state.name], state.name, fig) else: - bars = [b for b in info.bars if b.state == s] - _plot_bars(bars, _cm[s.name], s.name, fig) + bars = [b for b in info.bars if b.state == state] + _plot_bars(bars, _cm[state.name], state.name, fig) fig.update_xaxes(type="date") fig.update_layout( go.Layout( diff --git a/optuna/visualization/matplotlib/_timeline.py b/optuna/visualization/matplotlib/_timeline.py index 459074bd69..5de5c55d97 100644 --- a/optuna/visualization/matplotlib/_timeline.py +++ b/optuna/visualization/matplotlib/_timeline.py @@ -107,11 +107,11 @@ def _get_timeline_plot(info: _TimelineInfo) -> "Axes": fig.tight_layout() assert len(info.bars) > 0 - start_time = min([b.start for b in info.bars]) - complete_time = max([b.complete for b in info.bars]) - margin = (complete_time - start_time) * 0.05 + first_start_time = min([b.start for b in info.bars]) + last_complete_time = max([b.complete for b in info.bars]) + margin = (last_complete_time - first_start_time) * 0.05 - ax.set_xlim(right=complete_time + margin, left=start_time - margin) + ax.set_xlim(right=last_complete_time + margin, left=first_start_time - margin) ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True)) ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%H:%M:%S")) plt.gcf().autofmt_xdate() From c9d9d8909dc325a374c1b441a8b1d42cc0fbb97a Mon Sep 17 00:00:00 2001 From: Kento Nozawa Date: Sat, 24 Feb 2024 19:40:40 +0900 Subject: [PATCH 2/6] Fix typo, comment, and docstring --- optuna/storages/_journal/file.py | 11 ++++++----- optuna/study/_tell.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/optuna/storages/_journal/file.py b/optuna/storages/_journal/file.py index 8c32bf7450..6ccb6cbc5e 100644 --- a/optuna/storages/_journal/file.py +++ b/optuna/storages/_journal/file.py @@ -49,7 +49,7 @@ def acquire(self) -> bool: """Acquire a lock in a blocking way by creating a symbolic link of a file. Returns: - :obj:`True` if it succeeded in creating a symbolic link of `self._lock_target_file`. + :obj:`True` if it succeeded in creating a symbolic link of ``self._lock_target_file``. """ sleep_secs = 0.001 @@ -101,7 +101,7 @@ def acquire(self) -> bool: """Acquire a lock in a blocking way by creating a lock file. Returns: - :obj:`True` if it succeeded in creating a `self._lock_file` + :obj:`True` if it succeeded in creating a ``self._lock_file``. """ sleep_secs = 0.001 @@ -158,13 +158,14 @@ def __init__(self, file_path: str, lock_obj: Optional[JournalFileBaseLock] = Non self._file_path: str = file_path self._lock = lock_obj or JournalFileSymlinkLock(self._file_path) if not os.path.exists(self._file_path): - open(self._file_path, "ab").close() # Create a file if it does not exist + open(self._file_path, "ab").close() # Create a file if it does not exist. self._log_number_offset: Dict[int, int] = {0: 0} def read_logs(self, log_number_from: int) -> List[Dict[str, Any]]: logs = [] with open(self._file_path, "rb") as f: - # Maintain remaining_log_size to allow writing by another process while reading the log + # Maintain remaining_log_size to allow writing by another process + # while reading the log. remaining_log_size = os.stat(self._file_path).st_size log_number_start = 0 if log_number_from in self._log_number_offset: @@ -187,7 +188,7 @@ def read_logs(self, log_number_from: int) -> List[Dict[str, Any]]: if log_number < log_number_from: continue - # Ensure that each line ends with line separators (\n, \r\n) + # Ensure that each line ends with line separators (\n, \r\n). if not line.endswith(b"\n"): last_decode_error = ValueError("Invalid log format.") del self._log_number_offset[log_number + 1] diff --git a/optuna/study/_tell.py b/optuna/study/_tell.py index d9fdadab30..150436f4e3 100644 --- a/optuna/study/_tell.py +++ b/optuna/study/_tell.py @@ -164,7 +164,7 @@ def _tell_with_warning( # Cast values to list of floats. if values is not None: - # values have beed checked to be castable to floats in _check_values_are_feasible. + # values have been checked to be castable to floats in _check_values_are_feasible. values = [float(value) for value in values] # Post-processing and storing the trial. From 11721cc081ec290a5c16c09de57b5be06818a51b Mon Sep 17 00:00:00 2001 From: Hiroki Takizawa Date: Mon, 26 Feb 2024 13:26:43 +0900 Subject: [PATCH 3/6] Add a safety guard to Wilcoxon pruner, and modify the docstring (#5256) * Update _wilcoxon.py * Update test_wilcoxon.py * Update _wilcoxon.py * fix comment * fix mypy * Update optuna/pruners/_wilcoxon.py Co-authored-by: contramundum53 * modified comments and docstring * fix black * fix docstring * change the word technique to workaround --------- Co-authored-by: contramundum53 --- optuna/pruners/_wilcoxon.py | 47 +++++++++++++++++++++------- tests/pruners_tests/test_wilcoxon.py | 35 ++++++++++++++++++--- 2 files changed, 67 insertions(+), 15 deletions(-) diff --git a/optuna/pruners/_wilcoxon.py b/optuna/pruners/_wilcoxon.py index efba052dd7..a6c55f4c8a 100644 --- a/optuna/pruners/_wilcoxon.py +++ b/optuna/pruners/_wilcoxon.py @@ -32,8 +32,8 @@ class WilcoxonPruner(BasePruner): This includes the mean performance of n (e.g., 100) shuffled inputs, the mean performance of k-fold cross validation, etc. There can be "easy" or "hard" inputs (the pruner handles correspondence of - the inputs between different trials), - but it is recommended to shuffle the order of inputs once before the optimization. + the inputs between different trials). + In each trial, it is recommended to shuffle the order in which data is processed. When you use this pruner, you must call `Trial.report(value, step)` function for each step (e.g., input id) with the evaluated value. This is different from other pruners in that the reported value need not converge @@ -58,22 +58,23 @@ def eval_func(param, input_): input_data = np.linspace(-1, 1, 100) - - # It is recommended to shuffle the input data once before optimization. - np.random.shuffle(input_data) + rng = np.random.default_rng() def objective(trial): - s = 0.0 - for i in range(len(input_data)): + # In each trial, it is recommended to shuffle the order in which data is processed. + ordering = rng.permutation(range(len(input_data))) + s = [] + for i in ordering: param = trial.suggest_float("param", -1, 1) loss = eval_func(param, input_data[i]) trial.report(loss, i) - s += loss + s.append(loss) if trial.should_prune(): - raise optuna.TrialPruned() + return sum(s) / len(s) # An advanced workaround (see the note below). + # raise optuna.TrialPruned() - return s / len(input_data) + return sum(s) / len(s) study = optuna.study.create_study( @@ -86,6 +87,16 @@ def objective(trial): This pruner cannot handle ``infinity`` or ``nan`` values. Trials containing those values are never pruned. + .. note:: + As an advanced workaround, if `trial.should_prune()` returns `True`, + you can return an estimation of the final value (e.g., the average of all evaluated values) + instead of `raise optuna.TrialPruned()`. + Some algorithms including `TPESampler` internally split trials into below (good) and above (bad), + and pruned trial will always be classified as above. + However, there are some trials that are slightly worse than the best trial and will be pruned, + but they should be classified as below (e.g., top 10%). + This workaround provides beneficial information about such trials to these algorithms. + Args: p_threshold: The p-value threshold for pruning. This value should be between 0 and 1. @@ -154,6 +165,8 @@ def prune(self, study: "optuna.study.Study", trial: FrozenTrial) -> bool: _, idx1, idx2 = np.intersect1d(steps, best_steps, return_indices=True) if len(idx1) < len(step_values): + # This if-statement is never satisfied if following "average_is_best" safety works, + # because the safety ensures that the best trial always has the all steps. warnings.warn( "WilcoxonPruner finds steps existing in the current trial " "but does not exist in the best trial. " @@ -165,8 +178,20 @@ def prune(self, study: "optuna.study.Study", trial: FrozenTrial) -> bool: if len(diff_values) < self._n_startup_steps: return False - alt = "less" if study.direction == StudyDirection.MAXIMIZE else "greater" + if study.direction == StudyDirection.MAXIMIZE: + alt = "less" + average_is_best = best_trial.value <= sum(step_values) / len(step_values) + else: + alt = "greater" + average_is_best = best_trial.value >= sum(step_values) / len(step_values) # We use zsplit to avoid the problem when all values are zero. p = ss.wilcoxon(diff_values, alternative=alt, zero_method="zsplit").pvalue + + if p < self._p_threshold and average_is_best: + # ss.wilcoxon found the current trial is probably worse than the best trial, + # but the value of the best trial was not better than + # the average of the current trial's intermediate values. + # For safety, WilcoxonPruner concludes not to prune it for now. + return False return p < self._p_threshold diff --git a/tests/pruners_tests/test_wilcoxon.py b/tests/pruners_tests/test_wilcoxon.py index f7f7d5c688..89e17cd826 100644 --- a/tests/pruners_tests/test_wilcoxon.py +++ b/tests/pruners_tests/test_wilcoxon.py @@ -70,19 +70,24 @@ def test_wilcoxon_pruner_normal( @pytest.mark.parametrize( "best_intermediate_values,intermediate_values", [ - ({1: 1}, {1: 1, 2: 2}), # Current trial has more steps than best trial + ({1: 1}, {1: 1, 2: 2}), # Current trial has more steps than the best trial ({1: 1}, {1: float("nan")}), # NaN value ({1: float("nan")}, {1: 1}), # NaN value - ({1: 1}, {1: float("inf")}), # NaN value - ({1: float("inf")}, {1: 1}), # NaN value + ({1: 1}, {1: float("inf")}), # infinite value + ({1: float("inf")}, {1: 1}), # infinite value ], ) +@pytest.mark.parametrize( + "direction", + ("minimize", "maximize"), +) def test_wilcoxon_pruner_warn_bad_best_trial( best_intermediate_values: dict[int, float], intermediate_values: dict[int, float], + direction: str, ) -> None: pruner = optuna.pruners.WilcoxonPruner() - study = optuna.study.create_study(pruner=pruner) + study = optuna.study.create_study(direction=direction, pruner=pruner) # Insert best trial study.add_trial( @@ -95,3 +100,25 @@ def test_wilcoxon_pruner_warn_bad_best_trial( for step, value in intermediate_values.items(): trial.report(value, step) trial.should_prune() + + +def test_wilcoxon_pruner_if_average_is_best_then_not_prune() -> None: + pruner = optuna.pruners.WilcoxonPruner(p_threshold=0.5) + study = optuna.study.create_study(direction="minimize", pruner=pruner) + + best_intermediate_values_value = [0.0 for _ in range(10)] + [8.0 for _ in range(10)] + best_intermediate_values = dict(zip(list(range(20)), best_intermediate_values_value)) + + # Insert best trial + study.add_trial( + optuna.trial.create_trial( + value=4.0, params={}, distributions={}, intermediate_values=best_intermediate_values + ) + ) + trial = study.ask() + intermediate_values = [1.0 for _ in range(10)] + [9.0 for _ in range(10)] + for step, value in enumerate(intermediate_values): + trial.report(value, step) + average = sum(intermediate_values[: step + 1]) / (step + 1) + if average <= 4.0: + assert not trial.should_prune() From da8f74c64a4de947041d05113490a1ba1ad137da Mon Sep 17 00:00:00 2001 From: Kento Nozawa Date: Mon, 26 Feb 2024 14:33:08 +0900 Subject: [PATCH 4/6] Clarify that pruners module does not support multi-objective optimization (#5270) --- docs/source/reference/pruners.rst | 3 +++ .../10_key_features/003_efficient_optimization_algorithms.py | 1 + 2 files changed, 4 insertions(+) diff --git a/docs/source/reference/pruners.rst b/docs/source/reference/pruners.rst index cbd5c7b65a..3d3e3f2a03 100644 --- a/docs/source/reference/pruners.rst +++ b/docs/source/reference/pruners.rst @@ -5,6 +5,9 @@ optuna.pruners The :mod:`~optuna.pruners` module defines a :class:`~optuna.pruners.BasePruner` class characterized by an abstract :meth:`~optuna.pruners.BasePruner.prune` method, which, for a given trial and its associated study, returns a boolean value representing whether the trial should be pruned. This determination is made based on stored intermediate values of the objective function, as previously reported for the trial using :meth:`optuna.trial.Trial.report`. The remaining classes in this module represent child classes, inheriting from :class:`~optuna.pruners.BasePruner`, which implement different pruning strategies. +.. warning:: + Currently :mod:`~optuna.pruners` module is expected to be used only for single-objective optimization. + .. seealso:: :ref:`pruning` tutorial explains the concept of the pruner classes and a minimal example. diff --git a/tutorial/10_key_features/003_efficient_optimization_algorithms.py b/tutorial/10_key_features/003_efficient_optimization_algorithms.py index 0ac12996f1..b6970bd755 100644 --- a/tutorial/10_key_features/003_efficient_optimization_algorithms.py +++ b/tutorial/10_key_features/003_efficient_optimization_algorithms.py @@ -65,6 +65,7 @@ # ------------------ # # ``Pruners`` automatically stop unpromising trials at the early stages of the training (a.k.a., automated early-stopping). +# Currently :mod:`~optuna.pruners` module is expected to be used only for single-objective optimization. # # Optuna provides the following pruning algorithms: # From d11cb5d54c788345518eb07ff9e1eac66eb3ce5c Mon Sep 17 00:00:00 2001 From: Kento Nozawa Date: Mon, 26 Feb 2024 14:34:49 +0900 Subject: [PATCH 5/6] Unify indent size, two (#5271) --- pyproject.toml | 146 ++++++++++++++++++++++++------------------------- 1 file changed, 73 insertions(+), 73 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 31c6be95a8..2050f73b45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,24 +11,24 @@ authors = [ {name = "Takuya Akiba"} ] classifiers = [ - "Development Status :: 5 - Production/Stable", - "Intended Audience :: Science/Research", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3 :: Only", - "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Mathematics", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Software Development", - "Topic :: Software Development :: Libraries", - "Topic :: Software Development :: Libraries :: Python Modules", + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3 :: Only", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", ] requires-python = ">=3.7" dependencies = [ @@ -50,64 +50,64 @@ benchmark = [ "virtualenv" ] checking = [ - "black", - "blackdoc", - "flake8", - "isort", - "mypy", - "mypy_boto3_s3", - "types-PyYAML", - "types-redis", - "types-setuptools", - "types-tqdm", - "typing_extensions>=3.10.0.0", + "black", + "blackdoc", + "flake8", + "isort", + "mypy", + "mypy_boto3_s3", + "types-PyYAML", + "types-redis", + "types-setuptools", + "types-tqdm", + "typing_extensions>=3.10.0.0", ] document = [ - "ase", - "cmaes>=0.10.0", # optuna/samplers/_cmaes.py. - "fvcore", - "lightgbm", - "matplotlib!=3.6.0", - "pandas", - "pillow", - "plotly>=4.9.0", # optuna/visualization. - "scikit-learn", - "sphinx", - "sphinx-copybutton", - "sphinx-gallery", - "sphinx-plotly-directive", - "sphinx_rtd_theme>=1.2.0", - "torch", - "torchvision", + "ase", + "cmaes>=0.10.0", # optuna/samplers/_cmaes.py. + "fvcore", + "lightgbm", + "matplotlib!=3.6.0", + "pandas", + "pillow", + "plotly>=4.9.0", # optuna/visualization. + "scikit-learn", + "sphinx", + "sphinx-copybutton", + "sphinx-gallery", + "sphinx-plotly-directive", + "sphinx_rtd_theme>=1.2.0", + "torch", + "torchvision", ] integration = [ - "scikit-learn>=0.24.2", - "shap", - "tensorflow", + "scikit-learn>=0.24.2", + "shap", + "tensorflow", ] optional = [ - "boto3", # optuna/artifacts/_boto3.py. - "cmaes>=0.10.0", # optuna/samplers/_cmaes.py. - "google-cloud-storage", # optuna/artifacts/_gcs.py. - "matplotlib!=3.6.0", # optuna/visualization/matplotlib. - "pandas", # optuna/study.py. - "plotly>=4.9.0", # optuna/visualization. - "redis", # optuna/storages/redis.py. - "scikit-learn>=0.24.2", - # optuna/visualization/param_importances.py. - "scipy", # optuna/samplers/_gp - # TODO(contramundum53): Remove the constraint after torch supports python 3.12. - "torch; python_version<='3.11'", # optuna/samplers/_gp + "boto3", # optuna/artifacts/_boto3.py. + "cmaes>=0.10.0", # optuna/samplers/_cmaes.py. + "google-cloud-storage", # optuna/artifacts/_gcs.py. + "matplotlib!=3.6.0", # optuna/visualization/matplotlib. + "pandas", # optuna/study.py. + "plotly>=4.9.0", # optuna/visualization. + "redis", # optuna/storages/redis.py. + "scikit-learn>=0.24.2", + # optuna/visualization/param_importances.py. + "scipy", # optuna/samplers/_gp + # TODO(contramundum53): Remove the constraint after torch supports python 3.12. + "torch; python_version<='3.11'", # optuna/samplers/_gp ] test = [ - "coverage", - "fakeredis[lua]", - "kaleido", - "moto", - "pytest", - "scipy>=1.9.2; python_version>='3.8'", - # TODO(contramundum53): Remove the constraint after torch supports python 3.12. - "torch; python_version<='3.11'", + "coverage", + "fakeredis[lua]", + "kaleido", + "moto", + "pytest", + "scipy>=1.9.2; python_version>='3.8'", + # TODO(contramundum53): Remove the constraint after torch supports python 3.12. + "torch; python_version<='3.11'", ] [project.urls] @@ -127,10 +127,10 @@ version = {attr = "optuna.version.__version__"} [tool.setuptools.package-data] "optuna" = [ - "storages/_rdb/alembic.ini", - "storages/_rdb/alembic/*.*", - "storages/_rdb/alembic/versions/*.*", - "py.typed", + "storages/_rdb/alembic.ini", + "storages/_rdb/alembic/*.*", + "storages/_rdb/alembic/versions/*.*", + "py.typed", ] [tool.black] @@ -138,7 +138,7 @@ line-length = 99 target-version = ['py38'] exclude = ''' /( - \.eggs + \.eggs | \.git | \.hg | \.mypy_cache From ec0d5b507ed0a6845fcc6ca7178fba688e5a343c Mon Sep 17 00:00:00 2001 From: Gen <54583542+gen740@users.noreply.github.com> Date: Tue, 27 Feb 2024 14:03:03 +0900 Subject: [PATCH 6/6] Remove the Python version constraint for PyTorch (#5278) * Remove version constraint for pytorch * Remove import sys --- pyproject.toml | 6 ++---- tests/gp_tests/test_acqf.py | 8 -------- tests/samplers_tests/test_samplers.py | 7 ------- 3 files changed, 2 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2050f73b45..2af3f072e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,8 +96,7 @@ optional = [ "scikit-learn>=0.24.2", # optuna/visualization/param_importances.py. "scipy", # optuna/samplers/_gp - # TODO(contramundum53): Remove the constraint after torch supports python 3.12. - "torch; python_version<='3.11'", # optuna/samplers/_gp + "torch", # optuna/samplers/_gp ] test = [ "coverage", @@ -106,8 +105,7 @@ test = [ "moto", "pytest", "scipy>=1.9.2; python_version>='3.8'", - # TODO(contramundum53): Remove the constraint after torch supports python 3.12. - "torch; python_version<='3.11'", + "torch", ] [project.urls] diff --git a/tests/gp_tests/test_acqf.py b/tests/gp_tests/test_acqf.py index ae76644fdc..267dc34577 100644 --- a/tests/gp_tests/test_acqf.py +++ b/tests/gp_tests/test_acqf.py @@ -1,15 +1,7 @@ from __future__ import annotations -import sys - import numpy as np import pytest - - -# TODO(contramundum53): Remove this block after torch supports Python 3.12. -if sys.version_info >= (3, 12): - pytest.skip("PyTorch does not support python 3.12.", allow_module_level=True) - import torch from optuna._gp.acqf import AcquisitionFunctionType diff --git a/tests/samplers_tests/test_samplers.py b/tests/samplers_tests/test_samplers.py index 4ab9a9f5d6..a20272a18d 100644 --- a/tests/samplers_tests/test_samplers.py +++ b/tests/samplers_tests/test_samplers.py @@ -6,7 +6,6 @@ from multiprocessing.managers import DictProxy import os import pickle -import sys from typing import Any from unittest.mock import patch import warnings @@ -35,8 +34,6 @@ def get_gp_sampler( *, n_startup_trials: int = 0, seed: int | None = None ) -> optuna.samplers.GPSampler: - if sys.version_info >= (3, 12, 0): - pytest.skip("PyTorch does not support Python 3.12 yet.") return optuna.samplers.GPSampler(n_startup_trials=n_startup_trials, seed=seed) @@ -1024,10 +1021,6 @@ def restore_seed() -> None: @pytest.mark.slow @parametrize_sampler_name_with_seed def test_reproducible_in_other_process(sampler_name: str, unset_seed_in_test: None) -> None: - # TODO(HideakiImamura): Remove the constraint after torch supports python 3.12. - if sys.version_info >= (3, 12, 0) and sampler_name == "GPSampler": - pytest.skip("PyTorch does not support Python 3.12 yet.") - # This test should be tested without `PYTHONHASHSEED`. However, some tool such as tox # set the environmental variable "PYTHONHASHSEED" by default. # To do so, this test calls a finalizer: `unset_seed_in_test`.