diff --git a/CHANGES.rst b/CHANGES.rst index bbea9c4..ec7be09 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,18 @@ Version 0.31.0 Breaking changes: +- The typing hints introspection feature is automatically enabled for any + command (function) which does **not** have any arguments specified via `@arg` + decorator. + + This means that, for example, the following function used to fail and now + it will pass:: + + def main(count: int): + assert isinstance(count, int) + + This may lead to unexpected behaviour in some rare cases. + - A small change in the legacy argument mapping policy `BY_NAME_IF_HAS_DEFAULT` concerning the order of variadic positional vs. keyword-only arguments. @@ -23,8 +35,16 @@ Enhancements: - Added experimental support for basic typing hints (issue #203) - - The feature is automatically enabled for any command (function) which does - **not** have any arguments specified via `@arg` decorator. + The following hints are currently supported: + + - ``str``, ``int``, ``float``, ``bool`` (goes to ``type``); + - ``list`` (affects ``nargs``), ``list[T]`` (first subtype goes into ``type``); + - ``Optional[T]`` AKA ``T | None`` (currently interpreted as + ``required=False`` for optional and ``nargs="?"`` for positional + arguments; likely to change in the future as use cases accumulate). + + The exact interpretation of the type hints is subject to change in the + upcoming versions of Argh. - Added `always_flush` argument to `dispatch()` (issue #145) diff --git a/src/argh/assembling.py b/src/argh/assembling.py index bdf5427..2222d7b 100644 --- a/src/argh/assembling.py +++ b/src/argh/assembling.py @@ -19,7 +19,26 @@ from argparse import OPTIONAL, ZERO_OR_MORE, ArgumentParser from collections import OrderedDict from enum import Enum -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Tuple, + Union, + get_args, + get_origin, +) + +# types.UnionType was introduced in Python < 3.10 +try: # pragma: no cover + from types import UnionType + + UNION_TYPES = [Union, UnionType] +except ImportError: # pragma: no cover + UNION_TYPES = [Union] from argh.completion import COMPLETION_ENABLED from argh.constants import ( @@ -108,6 +127,7 @@ def func(alpha, beta=1, *, gamma, delta=2): ... def infer_argspecs_from_function( function: Callable, name_mapping_policy: Optional[NameMappingPolicy] = None, + can_use_hints: bool = False, ) -> Iterator[ParserAddArgumentSpec]: if getattr(function, ATTR_EXPECTS_NAMESPACE_OBJECT, False): return @@ -157,6 +177,17 @@ def _make_cli_arg_names_options(arg_name) -> Tuple[List[str], List[str]]: else: default_value = NotDefined + extra_spec_kwargs = {} + + if can_use_hints: + hints = function.__annotations__ + if parameter.name in hints: + extra_spec_kwargs = ( + TypingHintArgSpecGuesser.typing_hint_to_arg_spec_params( + hints[parameter.name] + ) + ) + if parameter.kind in ( parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD, @@ -206,6 +237,7 @@ def _make_cli_arg_names_options(arg_name) -> Tuple[List[str], List[str]]: func_arg_name=parameter.name, cli_arg_names=cli_arg_names_positional, default_value=default_value, + other_add_parser_kwargs=extra_spec_kwargs, ) if default_value != NotDefined: @@ -214,6 +246,14 @@ def _make_cli_arg_names_options(arg_name) -> Tuple[List[str], List[str]]: else: arg_spec.nargs = OPTIONAL + # "required" is invalid for positional CLI argument; + # it may have been set from Optional[...] hint above. + # Reinterpret it as "optional positional" instead. + if can_use_hints and "required" in arg_spec.other_add_parser_kwargs: + value = arg_spec.other_add_parser_kwargs.pop("required") + if value is False: + arg_spec.nargs = OPTIONAL + yield arg_spec elif parameter.kind == parameter.KEYWORD_ONLY: @@ -221,6 +261,7 @@ def _make_cli_arg_names_options(arg_name) -> Tuple[List[str], List[str]]: func_arg_name=parameter.name, cli_arg_names=cli_arg_names_positional, default_value=default_value, + other_add_parser_kwargs=extra_spec_kwargs, ) if name_mapping_policy == NameMappingPolicy.BY_NAME_IF_HAS_DEFAULT: @@ -238,6 +279,7 @@ def _make_cli_arg_names_options(arg_name) -> Tuple[List[str], List[str]]: func_arg_name=parameter.name, cli_arg_names=[parameter.name.replace("_", "-")], nargs=ZERO_OR_MORE, + other_add_parser_kwargs=extra_spec_kwargs, ) @@ -347,8 +389,16 @@ def set_default_command( has_varkw = any(p.kind == p.VAR_KEYWORD for p in func_signature.parameters.values()) declared_args: List[ParserAddArgumentSpec] = getattr(function, ATTR_ARGS, []) + + # transitional period: hints are used for types etc. only if @arg is not used + can_use_hints = not declared_args + inferred_args: List[ParserAddArgumentSpec] = list( - infer_argspecs_from_function(function, name_mapping_policy=name_mapping_policy) + infer_argspecs_from_function( + function, + name_mapping_policy=name_mapping_policy, + can_use_hints=can_use_hints, + ) ) if declared_args and not inferred_args and not has_varkw: @@ -662,3 +712,64 @@ def add_subcommands( class ArgumentNameMappingError(AssemblingError): ... + + +class TypingHintArgSpecGuesser: + BASIC_TYPES = (str, int, float, bool) + + @classmethod + def typing_hint_to_arg_spec_params( + cls, type_def: type, is_positional: bool = False + ) -> Dict[str, Any]: + origin = get_origin(type_def) + args = get_args(type_def) + + # if not origin and not args and type_def in BASIC_TYPES: + if type_def in cls.BASIC_TYPES: + # `str` + return { + "type": type_def + # "type": _parse_basic_type(type_def) + } + + if type_def == list: + # `list` + return {"nargs": "*"} + + if any(origin is t for t in UNION_TYPES): + # `str | int` + + retval = {} + first_subtype = args[0] + if first_subtype in cls.BASIC_TYPES: + retval["type"] = first_subtype + + if first_subtype == list: + retval["nargs"] = "*" + + if get_origin(first_subtype) == list: + retval["nargs"] = "*" + item_type = cls._extract_item_type_from_list_type(first_subtype) + if item_type: + retval["type"] = item_type + + if type(None) in args: + retval["required"] = False + return retval + + if origin == list: + # `list[str]` + retval = {} + retval["nargs"] = "*" + if args[0] in cls.BASIC_TYPES: + retval["type"] = args[0] + return retval + + return {} + + @classmethod + def _extract_item_type_from_list_type(cls, type_def) -> Optional[type]: + args = get_args(type_def) + if args[0] in cls.BASIC_TYPES: + return args[0] + return None diff --git a/tests/test_assembling.py b/tests/test_assembling.py index f4af39e..ef56db4 100644 --- a/tests/test_assembling.py +++ b/tests/test_assembling.py @@ -3,6 +3,7 @@ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ """ import argparse +from typing import Optional from unittest.mock import MagicMock, call, patch import pytest @@ -762,3 +763,75 @@ def test_is_positional(): # this spec is invalid but validation is out of scope of the function # as it only checks if the first argument has the leading dash assert argh.assembling._is_positional(["-f", "foo"]) is False + + +def test_typing_hints_only_used_when_arg_deco_not_used(): + @argh.arg("foo", type=int) + def func_decorated(foo: Optional[float]): + ... + + def func_undecorated(bar: Optional[float]): + ... + + parser = argparse.ArgumentParser() + parser.add_argument = MagicMock() + argh.set_default_command(parser, func_decorated) + assert parser.add_argument.mock_calls == [ + call("foo", type=int, help=argh.constants.DEFAULT_ARGUMENT_TEMPLATE), + ] + + parser = argparse.ArgumentParser() + parser.add_argument = MagicMock() + argh.set_default_command(parser, func_undecorated) + assert parser.add_argument.mock_calls == [ + call( + "bar", + nargs="?", + type=float, + help=argh.constants.DEFAULT_ARGUMENT_TEMPLATE, + ), + ] + + +def test_typing_hints(): + def func( + alpha, + beta: str, + gamma: Optional[int] = None, + *, + delta: float = 1.5, + epsilon: Optional[int] = 42, + ) -> str: + return f"alpha={alpha}, beta={beta}, gamma={gamma}, delta={delta}, epsilon={epsilon}" + + parser = argparse.ArgumentParser() + parser.add_argument = MagicMock() + argh.set_default_command( + parser, func, name_mapping_policy=NameMappingPolicy.BY_NAME_IF_KWONLY + ) + assert parser.add_argument.mock_calls == [ + call("alpha", help=argh.constants.DEFAULT_ARGUMENT_TEMPLATE), + call("beta", type=str, help=argh.constants.DEFAULT_ARGUMENT_TEMPLATE), + call( + "gamma", + default=None, + nargs="?", + type=int, + help=argh.constants.DEFAULT_ARGUMENT_TEMPLATE, + ), + call( + "-d", + "--delta", + type=float, + default=1.5, + help=argh.constants.DEFAULT_ARGUMENT_TEMPLATE, + ), + call( + "-e", + "--epsilon", + type=int, + default=42, + required=False, + help=argh.constants.DEFAULT_ARGUMENT_TEMPLATE, + ), + ] diff --git a/tests/test_typing_hints.py b/tests/test_typing_hints.py index d8ec695..250041e 100644 --- a/tests/test_typing_hints.py +++ b/tests/test_typing_hints.py @@ -1,176 +1,52 @@ -from types import UnionType -from typing import Any, get_args, get_origin, Union +from typing import List, Optional, Union import pytest +from argh.assembling import TypingHintArgSpecGuesser -BASIC_TYPES = (str, int, float, bool) - -@pytest.mark.parametrize("arg_type", BASIC_TYPES) +@pytest.mark.parametrize("arg_type", TypingHintArgSpecGuesser.BASIC_TYPES) def test_simple_types(arg_type): + guess = TypingHintArgSpecGuesser.typing_hint_to_arg_spec_params + # just the basic type - assert typing_hint_to_arg_spec_params(arg_type) == { - "type": arg_type - } + assert guess(arg_type) == {"type": arg_type} # basic type or None - assert typing_hint_to_arg_spec_params(arg_type | None) == { + assert guess(Optional[arg_type]) == { "type": arg_type, - "required": False - } - assert typing_hint_to_arg_spec_params(None | arg_type) == { -# "type": arg_type, - "required": False + "required": False, } + assert guess(Union[None, arg_type]) == {"required": False} # multiple basic types: the first one is used and None is looked up - assert typing_hint_to_arg_spec_params(arg_type | str | None) == { + assert guess(Union[arg_type, str, None]) == { "type": arg_type, - "required": False + "required": False, } - assert typing_hint_to_arg_spec_params(str | arg_type | None) == { + assert guess(Union[str, arg_type, None]) == { "type": str, - "required": False + "required": False, } def test_list(): - assert typing_hint_to_arg_spec_params(list) == {"nargs": "*"} - assert typing_hint_to_arg_spec_params(list | None) == {"nargs": "*", "required": False} + guess = TypingHintArgSpecGuesser.typing_hint_to_arg_spec_params - assert typing_hint_to_arg_spec_params(list[str]) == {"nargs": "*", "type": str} - assert typing_hint_to_arg_spec_params(list[str] | None) == {"nargs": "*", "type": str, "required": False} + assert guess(list) == {"nargs": "*"} + assert guess(Optional[list]) == {"nargs": "*", "required": False} - assert typing_hint_to_arg_spec_params(list[str, int]) == {"nargs": "*", "type": str} - assert typing_hint_to_arg_spec_params(list[str, int] | None) == { - "nargs": "*", - "type": str, - "required": False - } -# assert typing_hint_to_arg_spec_params(list[str | None]) == { -# "type": str, -# "nargs": "*", -# } -# assert typing_hint_to_arg_spec_params(list[str | None] | None) == { -# "type": str, -# "nargs": "+", -# "required": False -# } + assert guess(List[str]) == {"nargs": "*", "type": str} + assert guess(List[int]) == {"nargs": "*", "type": int} + assert guess(Optional[List[str]]) == {"nargs": "*", "type": str, "required": False} + assert guess(Optional[List[tuple]]) == {"nargs": "*", "required": False} - assert typing_hint_to_arg_spec_params(list[list]) == { - "nargs": "*", - } - assert typing_hint_to_arg_spec_params(list[list, str]) == { - "nargs": "*", - } - assert typing_hint_to_arg_spec_params(list[tuple]) == { - "nargs": "*", - } + assert guess(List[list]) == {"nargs": "*"} + assert guess(List[tuple]) == {"nargs": "*"} @pytest.mark.parametrize("arg_type", (dict, tuple)) def test_unusable_types(arg_type): - assert typing_hint_to_arg_spec_params(arg_type) == {} - - -def typing_hint_to_arg_spec_params(type_def: type) -> dict[str, Any]: - origin = get_origin(type_def) - args = get_args(type_def) - - print("--------------------------------") - print(f"PARSE type_def: {type_def}, origin: {origin}, args: {args}") - - #if not origin and not args and type_def in BASIC_TYPES: - if type_def in BASIC_TYPES: - print("* basic type") - return { - "type": type_def - #"type": _parse_basic_type(type_def) - } - - if type_def == list: - print("* list (no nested types)") - return {"nargs": "*"} - - if origin == UnionType: - print("* union") - #return _parse_union_type(args) - retval = {} - #first_subtype = [t for t in args if not isinstance(None, t)][0] - first_subtype = args[0] - print("first_subtype", first_subtype) - if first_subtype in BASIC_TYPES: - retval["type"] = first_subtype - - if first_subtype == list: - retval["nargs"] = "*" - - if get_origin(first_subtype) == list: - retval["nargs"] = "*" - item_type = _extract_item_type_from_list_type(first_subtype) - print(f"item type {item_type}") - if item_type: - retval["type"] = item_type - - if type(None) in args: - retval["required"] = False - return retval - - if origin == list: - print("* list (with nested types)") - retval = {} - retval["nargs"] = "*" - print(f"item type {args[0]}") - if args[0] in BASIC_TYPES: - retval["type"] = args[0] - return retval - - print("huh??") - return {} - - -def _extract_item_type_from_list_type(type_def) -> type | None: - print("_extract_item_type_from_list_type", type_def) - args = get_args(type_def) - if not args: - return - if args[0] in BASIC_TYPES: - return args[0] - return None - - -# if origin == Union: -# return _parse_union_type(get_args(type_def)) - -# if origin == list: -# return _parse_list_type(get_args(type_def)) - -# parsed_single = _parse_concrete_typ - -# if origin in (str, int, float, bool): -# return origin -# if origin == list: - - -#def _parse_basic_type(type_def: type) -> dict[str, Any]: -# print("parse basic type", type_def) -# return type_def - - -def _parse_union_type(types: list[type]) -> dict[str, Any]: - print("parse union type", types) - return { - "type": [t for t in types if not isinstance(None, t)][0], - "required": type(None) not in types, - } - + guess = TypingHintArgSpecGuesser.typing_hint_to_arg_spec_params -def _parse_list_type(types: list[type]) -> dict[str, Any]: - print("parse list type", types) - if types: - # just take the first item - return { - "type": types[0] - } - return {} + assert guess(arg_type) == {} diff --git a/tests/test_utils.py b/tests/test_utils.py index b591d90..df91f47 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,17 +3,11 @@ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ """ -import functools from argparse import ArgumentParser, _SubParsersAction -from typing import Callable import pytest -from argh.utils import ( - SubparsersNotDefinedError, - get_subparsers, - unindent, -) +from argh.utils import SubparsersNotDefinedError, get_subparsers, unindent def test_util_unindent():