Skip to content

Commit

Permalink
Change variable names for readability
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Oct 31, 2023
1 parent e763c2c commit 6a44734
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions optuna/_convert_positional_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def _get_positional_arg_names(func: "Callable[_P, _T]") -> list[str]:
return positional_arg_names


def _infer_given_args(previous_positional_arg_names: Sequence[str], *args: Any) -> dict[str, Any]:
inferred_args = {arg_name: val for val, arg_name in zip(args, previous_positional_arg_names)}
return inferred_args
def _infer_kwargs(previous_positional_arg_names: Sequence[str], *args: Any) -> dict[str, Any]:
inferred_kwargs = {arg_name: val for val, arg_name in zip(args, previous_positional_arg_names)}
return inferred_kwargs


def convert_positional_args(
Expand All @@ -54,11 +54,11 @@ def converter_decorator(func: "Callable[_P, _T]") -> "Callable[_P, _T]":
@wraps(func)
def converter_wrapper(*args: Any, **kwargs: Any) -> "_T":
positional_arg_names = _get_positional_arg_names(func)
inferred_args = _infer_given_args(previous_positional_arg_names, *args)
if len(inferred_args) > len(positional_arg_names):
kwargs_expected = set(inferred_args) - set(positional_arg_names)
inferred_kwargs = _infer_kwargs(previous_positional_arg_names, *args)
if len(inferred_kwargs) > len(positional_arg_names):
expected_kwds = set(inferred_kwargs) - set(positional_arg_names)
warnings.warn(
f"{func.__name__}() got {kwargs_expected} as positional arguments "
f"{func.__name__}() got {expected_kwds} as positional arguments "
"but they were expected to be given as keyword arguments.",
FutureWarning,
stacklevel=warning_stacklevel,
Expand All @@ -69,16 +69,16 @@ def converter_wrapper(*args: Any, **kwargs: Any) -> "_T":
f" arguments but {len(args)} were given."
)

duplicated_arg_names = set(kwargs).intersection(inferred_args)
if len(duplicated_arg_names):
duplicated_kwds = set(kwargs).intersection(inferred_kwargs)
if len(duplicated_kwds):
# When specifying positional arguments that are not located at the end of args as
# keyword arguments, raise TypeError as follows by imitating the Python standard
# behavior
raise TypeError(
f"{func.__name__}() got multiple values for arguments {duplicated_arg_names}."
f"{func.__name__}() got multiple values for arguments {duplicated_kwds}."
)

kwargs.update(inferred_args)
kwargs.update(inferred_kwargs)

return func(**kwargs)

Expand Down

0 comments on commit 6a44734

Please sign in to comment.