Merge pull request optuna#5044 from nabenabe0928/code-fix/change-args-to-kwargs-in-suggest-int

Convert positional args to kwargs in suggest_int
HideakiImamura authored Nov 2, 2023
2 parents f32428e + 6a44734 commit 4694831
Showing 11 changed files with 150 additions and 20 deletions.
44 changes: 32 additions & 12 deletions optuna/_convert_positional_args.py
@@ -3,6 +3,7 @@
from collections.abc import Callable
from collections.abc import Sequence
from functools import wraps
from inspect import Parameter
from inspect import signature
from typing import Any
from typing import TYPE_CHECKING
@@ -17,6 +18,21 @@
_T = TypeVar("_T")


def _get_positional_arg_names(func: "Callable[_P, _T]") -> list[str]:
params = signature(func).parameters
positional_arg_names = [
name
for name, p in params.items()
if p.default == Parameter.empty and p.kind == p.POSITIONAL_OR_KEYWORD
]
return positional_arg_names


def _infer_kwargs(previous_positional_arg_names: Sequence[str], *args: Any) -> dict[str, Any]:
inferred_kwargs = {arg_name: val for val, arg_name in zip(args, previous_positional_arg_names)}
return inferred_kwargs


def convert_positional_args(
*,
previous_positional_arg_names: Sequence[str],
@@ -37,10 +53,13 @@ def converter_decorator(func: "Callable[_P, _T]") -> "Callable[_P, _T]":

@wraps(func)
def converter_wrapper(*args: Any, **kwargs: Any) -> "_T":
if len(args) >= 1:
positional_arg_names = _get_positional_arg_names(func)
inferred_kwargs = _infer_kwargs(previous_positional_arg_names, *args)
if len(inferred_kwargs) > len(positional_arg_names):
expected_kwds = set(inferred_kwargs) - set(positional_arg_names)
warnings.warn(
f"{func.__name__}(): Please give all values as keyword arguments."
" See https://github.com/optuna/optuna/issues/3324 for details.",
f"{func.__name__}() got {expected_kwds} as positional arguments "
"but they were expected to be given as keyword arguments.",
FutureWarning,
stacklevel=warning_stacklevel,
)
@@ -50,15 +69,16 @@ def converter_wrapper(*args: Any, **kwargs: Any) -> "_T":
f" arguments but {len(args)} were given."
)

for val, arg_name in zip(args, previous_positional_arg_names):
# When specifying a positional argument that is not located at the end of args as
# a keyword argument, raise TypeError as follows by imitating the Python standard
# behavior.
if arg_name in kwargs:
raise TypeError(
f"{func.__name__}() got multiple values for argument '{arg_name}'."
)
kwargs[arg_name] = val
duplicated_kwds = set(kwargs).intersection(inferred_kwargs)
if len(duplicated_kwds):
# When specifying positional arguments that are not located at the end of args as
# keyword arguments, raise TypeError as follows by imitating the Python standard
# behavior
raise TypeError(
f"{func.__name__}() got multiple values for arguments {duplicated_kwds}."
)

kwargs.update(inferred_kwargs)

return func(**kwargs)

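For context, a minimal usage sketch of the reworked decorator (not part of this diff; it imports the private helper shown above, so treat the module path as an internal detail, and suggest_int_like is a hypothetical stand-in): a call that passes step positionally is mapped back to keyword arguments and emits the FutureWarning.

import warnings

from optuna._convert_positional_args import convert_positional_args


@convert_positional_args(previous_positional_arg_names=["name", "low", "high", "step", "log"])
def suggest_int_like(name: str, low: int, high: int, *, step: int = 1, log: bool = False) -> int:
    # Stand-in for suggest_int; only the argument handling matters here.
    return low


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    value = suggest_int_like("x", -1, 1, 2)  # step given positionally (old style)

assert value == -1
# The old style still works, but a FutureWarning is recorded.
assert any(issubclass(w.category, FutureWarning) for w in caught)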
7 changes: 6 additions & 1 deletion optuna/multi_objective/trial.py
@@ -8,12 +8,14 @@
from typing import Union

from optuna import multi_objective
from optuna._convert_positional_args import convert_positional_args
from optuna._deprecated import deprecated_class
from optuna.distributions import BaseDistribution
from optuna.study._study_direction import StudyDirection
from optuna.trial import FrozenTrial
from optuna.trial import Trial
from optuna.trial import TrialState
from optuna.trial._base import _SUGGEST_INT_POSITIONAL_ARGS


CategoricalChoiceType = Union[None, bool, int, float, str]
@@ -87,7 +89,10 @@ def suggest_discrete_uniform(self, name: str, low: float, high: float, q: float)

return self._trial.suggest_discrete_uniform(name, low, high, q)

def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int:
@convert_positional_args(previous_positional_arg_names=_SUGGEST_INT_POSITIONAL_ARGS)
def suggest_int(
self, name: str, low: int, high: int, *, step: int = 1, log: bool = False
) -> int:
"""Suggest a value for the integer parameter.
Please refer to the documentation of :func:`optuna.trial.Trial.suggest_int`
7 changes: 6 additions & 1 deletion optuna/trial/_base.py
@@ -11,6 +11,9 @@
from optuna.distributions import CategoricalChoiceType


_SUGGEST_INT_POSITIONAL_ARGS = ["self", "name", "low", "high", "step", "log"]


class BaseTrial(abc.ABC):
"""Base class for trials.
@@ -45,7 +48,9 @@ def suggest_discrete_uniform(self, name: str, low: float, high: float, q: float)
raise NotImplementedError

@abc.abstractmethod
def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int:
def suggest_int(
self, name: str, low: int, high: int, *, step: int = 1, log: bool = False
) -> int:
raise NotImplementedError

@overload
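As an illustration (not part of the diff), the shared _SUGGEST_INT_POSITIONAL_ARGS constant includes "self" so the decorator can map a bound method's positional arguments, including the instance itself. The Trial subclasses in the files below reuse it as sketched here; the _MyTrial class is hypothetical.

from optuna._convert_positional_args import convert_positional_args
from optuna.trial._base import _SUGGEST_INT_POSITIONAL_ARGS


class _MyTrial:
    # Mirrors the decorated overrides in _fixed.py, _frozen.py, and _trial.py below.
    @convert_positional_args(previous_positional_arg_names=_SUGGEST_INT_POSITIONAL_ARGS)
    def suggest_int(
        self, name: str, low: int, high: int, *, step: int = 1, log: bool = False
    ) -> int:
        return low


_MyTrial().suggest_int("x", 0, 10, 2)  # emits FutureWarning because step is positional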
7 changes: 6 additions & 1 deletion optuna/trial/_fixed.py
@@ -7,12 +7,14 @@
import warnings

from optuna import distributions
from optuna._convert_positional_args import convert_positional_args
from optuna._deprecated import deprecated_func
from optuna.distributions import BaseDistribution
from optuna.distributions import CategoricalChoiceType
from optuna.distributions import CategoricalDistribution
from optuna.distributions import FloatDistribution
from optuna.distributions import IntDistribution
from optuna.trial._base import _SUGGEST_INT_POSITIONAL_ARGS
from optuna.trial._base import BaseTrial


@@ -89,7 +91,10 @@ def suggest_loguniform(self, name: str, low: float, high: float) -> float:
def suggest_discrete_uniform(self, name: str, low: float, high: float, q: float) -> float:
return self.suggest_float(name, low, high, step=q)

def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int:
@convert_positional_args(previous_positional_arg_names=_SUGGEST_INT_POSITIONAL_ARGS)
def suggest_int(
self, name: str, low: int, high: int, *, step: int = 1, log: bool = False
) -> int:
return int(self._suggest(name, IntDistribution(low, high, log=log, step=step)))

@overload
7 changes: 6 additions & 1 deletion optuna/trial/_frozen.py
@@ -11,6 +11,7 @@

from optuna import distributions
from optuna import logging
from optuna._convert_positional_args import convert_positional_args
from optuna._deprecated import deprecated_func
from optuna._typing import JSONSerializable
from optuna.distributions import _convert_old_distribution_to_new_distribution
@@ -19,6 +20,7 @@
from optuna.distributions import CategoricalDistribution
from optuna.distributions import FloatDistribution
from optuna.distributions import IntDistribution
from optuna.trial._base import _SUGGEST_INT_POSITIONAL_ARGS
from optuna.trial._base import BaseTrial
from optuna.trial._state import TrialState

@@ -225,7 +227,10 @@ def suggest_loguniform(self, name: str, low: float, high: float) -> float:
def suggest_discrete_uniform(self, name: str, low: float, high: float, q: float) -> float:
return self.suggest_float(name, low, high, step=q)

def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int:
@convert_positional_args(previous_positional_arg_names=_SUGGEST_INT_POSITIONAL_ARGS)
def suggest_int(
self, name: str, low: int, high: int, *, step: int = 1, log: bool = False
) -> int:
return int(self._suggest(name, IntDistribution(low, high, log=log, step=step)))

@overload
7 changes: 6 additions & 1 deletion optuna/trial/_trial.py
@@ -14,13 +14,15 @@
from optuna import distributions
from optuna import logging
from optuna import pruners
from optuna._convert_positional_args import convert_positional_args
from optuna._deprecated import deprecated_func
from optuna.distributions import BaseDistribution
from optuna.distributions import CategoricalChoiceType
from optuna.distributions import CategoricalDistribution
from optuna.distributions import FloatDistribution
from optuna.distributions import IntDistribution
from optuna.trial import FrozenTrial
from optuna.trial._base import _SUGGEST_INT_POSITIONAL_ARGS
from optuna.trial._base import BaseTrial


@@ -235,7 +237,10 @@ def suggest_discrete_uniform(self, name: str, low: float, high: float, q: float)

return self.suggest_float(name, low, high, step=q)

def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int:
@convert_positional_args(previous_positional_arg_names=_SUGGEST_INT_POSITIONAL_ARGS)
def suggest_int(
self, name: str, low: int, high: int, *, step: int = 1, log: bool = False
) -> int:
"""Suggest a value for the integer parameter.
The value is sampled from the integers in :math:`[\\mathsf{low}, \\mathsf{high}]`.
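In terms of user-facing behavior, the calling conventions after this change look roughly as follows (a sketch, not taken from the diff; parameter names and values are illustrative).

import optuna


def objective(trial: optuna.Trial) -> float:
    x = trial.suggest_int("x", -10, 10, step=2, log=False)  # preferred: step/log as keywords
    y = trial.suggest_int("y", 0, 10, 2)  # old positional style: still works, emits FutureWarning
    return float(x + y)


study = optuna.create_study()
study.optimize(objective, n_trials=1)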
12 changes: 12 additions & 0 deletions tests/multi_objective_tests/test_trial.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import datetime
from typing import List
from typing import Tuple
@@ -205,3 +207,13 @@ def create_trial(
else:
# If `t1` isn't COMPLETE, it doesn't dominate others.
assert not t1._dominates(t0, directions)


@pytest.mark.parametrize("positional_args_names", [[], ["step"], ["step", "log"]])
def test_suggest_int_positional_args(positional_args_names: list[str]) -> None:
# If log is specified as positional, step must also be provided as positional.
study = optuna.multi_objective.create_study(["maximize"])
kwargs = dict(step=1, log=False)
args = [kwargs[name] for name in positional_args_names]
# No error should be raised even if the coding style is old.
study.optimize(lambda trial: [trial.suggest_int("x", -1, 1, *args)], n_trials=1)
21 changes: 20 additions & 1 deletion tests/test_convert_positional_args.py
@@ -10,6 +10,12 @@ def _sample_func(*, a: int, b: int, c: int) -> int:
return a + b + c


class _SimpleClass:
@convert_positional_args(previous_positional_arg_names=["self", "a", "b"])
def simple_method(self, a: int, *, b: int, c: int = 1) -> None:
pass


def test_convert_positional_args_decorator() -> None:
previous_positional_arg_names: List[str] = []
decorator_converter = convert_positional_args(
@@ -20,6 +26,19 @@ def test_convert_positional_args_decorator() -> None:
assert decorated_func.__name__ == _sample_func.__name__


def test_convert_positional_args_future_warning_for_methods() -> None:
simple_class = _SimpleClass()
with pytest.warns(FutureWarning) as record:
simple_class.simple_method(1, 2, c=3) # type: ignore
simple_class.simple_method(1, b=2, c=3) # No warning.
simple_class.simple_method(a=1, b=2, c=3) # No warning.

assert len(record) == 1
for warn in record.list:
assert isinstance(warn.message, FutureWarning)
assert "simple_method" in str(warn.message)


def test_convert_positional_args_future_warning() -> None:
previous_positional_arg_names: List[str] = ["a", "b"]
decorator_converter = convert_positional_args(
@@ -105,4 +124,4 @@ def test_convert_positional_args_invalid_positional_args() -> None:

with pytest.raises(TypeError) as record:
decorated_func(1, 3, b=2) # type: ignore
assert str(record.value) == "_sample_func() got multiple values for argument 'b'."
assert str(record.value) == "_sample_func() got multiple values for arguments {'b'}."
15 changes: 15 additions & 0 deletions tests/trial_tests/test_fixed.py
@@ -1,3 +1,7 @@
from __future__ import annotations

import pytest

from optuna.trial import FixedTrial


@@ -10,6 +14,17 @@ def test_params() -> None:
assert trial.params == params


@pytest.mark.parametrize("positional_args_names", [[], ["step"], ["step", "log"]])
def test_suggest_int_positional_args(positional_args_names: list[str]) -> None:
# If log is specified as positional, step must also be provided as positional.
params = {"x": 1}
trial = FixedTrial(params)
kwargs = dict(step=1, log=False)
args = [kwargs[name] for name in positional_args_names]
# No error should be raised even if the coding style is old.
trial.suggest_int("x", -1, 1, *args)


def test_number() -> None:
params = {"x": 1}
trial = FixedTrial(params, 2)
26 changes: 26 additions & 0 deletions tests/trial_tests/test_frozen.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import copy
import datetime
from typing import Any
@@ -11,6 +13,7 @@
from optuna import create_study
from optuna.distributions import BaseDistribution
from optuna.distributions import FloatDistribution
from optuna.distributions import IntDistribution
from optuna.testing.storages import STORAGE_MODES
from optuna.testing.storages import StorageSupplier
import optuna.trial
@@ -385,3 +388,26 @@ def test_create_trial_distribution_conversion_noop() -> None:

# Check fixed_distributions doesn't change.
assert trial.distributions == fixed_distributions


@pytest.mark.parametrize("positional_args_names", [[], ["step"], ["step", "log"]])
def test_suggest_int_positional_args(positional_args_names: list[str]) -> None:
# If log is specified as positional, step must also be provided as positional.
trial = FrozenTrial(
number=0,
trial_id=0,
state=TrialState.COMPLETE,
value=0.0,
values=None,
datetime_start=datetime.datetime.now(),
datetime_complete=datetime.datetime.now(),
params={"x": 1},
distributions={"x": IntDistribution(-1, 1)},
user_attrs={},
system_attrs={},
intermediate_values={},
)
kwargs = dict(step=1, log=False)
args = [kwargs[name] for name in positional_args_names]
# No error should be raised even if the coding style is old.
trial.suggest_int("x", -1, 1, *args)
17 changes: 15 additions & 2 deletions tests/trial_tests/test_trial.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import datetime
import math
from typing import Any
@@ -128,11 +130,11 @@ def test_check_distribution_suggest_discrete_uniform(storage_mode: str) -> None:
assert len([r for r in record if r.category != FutureWarning]) == 1

with pytest.raises(ValueError):
trial.suggest_int("x", 10, 20, 2)
trial.suggest_int("x", 10, 20, step=2)

trial = Trial(study, study._storage.create_new_trial(study._study_id))
with pytest.raises(ValueError):
trial.suggest_int("x", 10, 20, 2)
trial.suggest_int("x", 10, 20, step=2)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
@@ -704,3 +706,14 @@ def test_lazy_trial_system_attrs(storage_mode: str) -> None:
system_attrs = _LazyTrialSystemAttrs(trial._trial_id, storage)
assert set(system_attrs.items()) == {("int", 0), ("str", "A")}
assert set(system_attrs.items()) == {("int", 0), ("str", "A")}


@pytest.mark.parametrize("positional_args_names", [[], ["step"], ["step", "log"]])
def test_suggest_int_positional_args(positional_args_names: list[str]) -> None:
# If log is specified as positional, step must also be provided as positional.
study = optuna.create_study()
trial = study.ask()
kwargs = dict(step=1, log=False)
args = [kwargs[name] for name in positional_args_names]
# No error should be raised even if the coding style is old.
trial.suggest_int("x", -1, 1, *args)
