Skip to content

Commit

Permalink
Remove the Python version constraint for PyTorch (optuna#5278)
Browse files Browse the repository at this point in the history
* Remove version constraint for pytorch

* Remove import sys
  • Loading branch information
gen740 authored Feb 27, 2024
1 parent e23623c commit ec0d5b5
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 19 deletions.
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ optional = [
"scikit-learn>=0.24.2",
# optuna/visualization/param_importances.py.
"scipy", # optuna/samplers/_gp
# TODO(contramundum53): Remove the constraint after torch supports python 3.12.
"torch; python_version<='3.11'", # optuna/samplers/_gp
"torch", # optuna/samplers/_gp
]
test = [
"coverage",
Expand All @@ -106,8 +105,7 @@ test = [
"moto",
"pytest",
"scipy>=1.9.2; python_version>='3.8'",
# TODO(contramundum53): Remove the constraint after torch supports python 3.12.
"torch; python_version<='3.11'",
"torch",
]

[project.urls]
Expand Down
8 changes: 0 additions & 8 deletions tests/gp_tests/test_acqf.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
from __future__ import annotations

import sys

import numpy as np
import pytest


# TODO(contramundum53): Remove this block after torch supports Python 3.12.
if sys.version_info >= (3, 12):
pytest.skip("PyTorch does not support python 3.12.", allow_module_level=True)

import torch

from optuna._gp.acqf import AcquisitionFunctionType
Expand Down
7 changes: 0 additions & 7 deletions tests/samplers_tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from multiprocessing.managers import DictProxy
import os
import pickle
import sys
from typing import Any
from unittest.mock import patch
import warnings
Expand Down Expand Up @@ -35,8 +34,6 @@
def get_gp_sampler(
*, n_startup_trials: int = 0, seed: int | None = None
) -> optuna.samplers.GPSampler:
if sys.version_info >= (3, 12, 0):
pytest.skip("PyTorch does not support Python 3.12 yet.")
return optuna.samplers.GPSampler(n_startup_trials=n_startup_trials, seed=seed)


Expand Down Expand Up @@ -1024,10 +1021,6 @@ def restore_seed() -> None:
@pytest.mark.slow
@parametrize_sampler_name_with_seed
def test_reproducible_in_other_process(sampler_name: str, unset_seed_in_test: None) -> None:
# TODO(HideakiImamura): Remove the constraint after torch supports python 3.12.
if sys.version_info >= (3, 12, 0) and sampler_name == "GPSampler":
pytest.skip("PyTorch does not support Python 3.12 yet.")

# This test should be tested without `PYTHONHASHSEED`. However, some tool such as tox
# set the environmental variable "PYTHONHASHSEED" by default.
# To do so, this test calls a finalizer: `unset_seed_in_test`.
Expand Down

0 comments on commit ec0d5b5

Please sign in to comment.