Skip to content

Commit

Permalink
feat: basic argspec guessing from typing hints
Browse files Browse the repository at this point in the history
  • Loading branch information
neithere committed Dec 29, 2023
1 parent aeec608 commit 08f8e16
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 159 deletions.
24 changes: 22 additions & 2 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@ Version 0.31.0

Breaking changes:

- The typing hints introspection feature is automatically enabled for any
command (function) which does **not** have any arguments specified via `@arg`
decorator.

This means that, for example, the following function used to fail and now
it will pass::

def main(count: int):
assert isinstance(count, int)

This may lead to unexpected behaviour in some rare cases.

- A small change in the legacy argument mapping policy `BY_NAME_IF_HAS_DEFAULT`
concerning the order of variadic positional vs. keyword-only arguments.

Expand All @@ -23,8 +35,16 @@ Enhancements:

- Added experimental support for basic typing hints (issue #203)

- The feature is automatically enabled for any command (function) which does
**not** have any arguments specified via `@arg` decorator.
The following hints are currently supported:

- ``str``, ``int``, ``float``, ``bool`` (goes to ``type``);
- ``list`` (affects ``nargs``), ``list[T]`` (first subtype goes into ``type``);
- ``Optional[T]`` AKA ``T | None`` (currently interpreted as
``required=False`` for optional and ``nargs="?"`` for positional
arguments; likely to change in the future as use cases accumulate).

The exact interpretation of the type hints is subject to change in the
upcoming versions of Argh.

- Added `always_flush` argument to `dispatch()` (issue #145)

Expand Down
115 changes: 113 additions & 2 deletions src/argh/assembling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,26 @@
from argparse import OPTIONAL, ZERO_OR_MORE, ArgumentParser
from collections import OrderedDict
from enum import Enum
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
Union,
get_args,
get_origin,
)

# types.UnionType was introduced in Python < 3.10
try: # pragma: no cover
from types import UnionType

UNION_TYPES = [Union, UnionType]
except ImportError: # pragma: no cover
UNION_TYPES = [Union]

from argh.completion import COMPLETION_ENABLED
from argh.constants import (
Expand Down Expand Up @@ -108,6 +127,7 @@ def func(alpha, beta=1, *, gamma, delta=2): ...
def infer_argspecs_from_function(
function: Callable,
name_mapping_policy: Optional[NameMappingPolicy] = None,
can_use_hints: bool = False,
) -> Iterator[ParserAddArgumentSpec]:
if getattr(function, ATTR_EXPECTS_NAMESPACE_OBJECT, False):
return
Expand Down Expand Up @@ -157,6 +177,17 @@ def _make_cli_arg_names_options(arg_name) -> Tuple[List[str], List[str]]:
else:
default_value = NotDefined

extra_spec_kwargs = {}

if can_use_hints:
hints = function.__annotations__
if parameter.name in hints:
extra_spec_kwargs = (
TypingHintArgSpecGuesser.typing_hint_to_arg_spec_params(
hints[parameter.name]
)
)

if parameter.kind in (
parameter.POSITIONAL_ONLY,
parameter.POSITIONAL_OR_KEYWORD,
Expand Down Expand Up @@ -206,6 +237,7 @@ def _make_cli_arg_names_options(arg_name) -> Tuple[List[str], List[str]]:
func_arg_name=parameter.name,
cli_arg_names=cli_arg_names_positional,
default_value=default_value,
other_add_parser_kwargs=extra_spec_kwargs,
)

if default_value != NotDefined:
Expand All @@ -214,13 +246,22 @@ def _make_cli_arg_names_options(arg_name) -> Tuple[List[str], List[str]]:
else:
arg_spec.nargs = OPTIONAL

# "required" is invalid for positional CLI argument;
# it may have been set from Optional[...] hint above.
# Reinterpret it as "optional positional" instead.
if can_use_hints and "required" in arg_spec.other_add_parser_kwargs:
value = arg_spec.other_add_parser_kwargs.pop("required")
if value is False:
arg_spec.nargs = OPTIONAL

yield arg_spec

elif parameter.kind == parameter.KEYWORD_ONLY:
arg_spec = ParserAddArgumentSpec(
func_arg_name=parameter.name,
cli_arg_names=cli_arg_names_positional,
default_value=default_value,
other_add_parser_kwargs=extra_spec_kwargs,
)

if name_mapping_policy == NameMappingPolicy.BY_NAME_IF_HAS_DEFAULT:
Expand All @@ -238,6 +279,7 @@ def _make_cli_arg_names_options(arg_name) -> Tuple[List[str], List[str]]:
func_arg_name=parameter.name,
cli_arg_names=[parameter.name.replace("_", "-")],
nargs=ZERO_OR_MORE,
other_add_parser_kwargs=extra_spec_kwargs,
)


Expand Down Expand Up @@ -347,8 +389,16 @@ def set_default_command(
has_varkw = any(p.kind == p.VAR_KEYWORD for p in func_signature.parameters.values())

declared_args: List[ParserAddArgumentSpec] = getattr(function, ATTR_ARGS, [])

# transitional period: hints are used for types etc. only if @arg is not used
can_use_hints = not declared_args

inferred_args: List[ParserAddArgumentSpec] = list(
infer_argspecs_from_function(function, name_mapping_policy=name_mapping_policy)
infer_argspecs_from_function(
function,
name_mapping_policy=name_mapping_policy,
can_use_hints=can_use_hints,
)
)

if declared_args and not inferred_args and not has_varkw:
Expand Down Expand Up @@ -662,3 +712,64 @@ def add_subcommands(

class ArgumentNameMappingError(AssemblingError):
...


class TypingHintArgSpecGuesser:
BASIC_TYPES = (str, int, float, bool)

@classmethod
def typing_hint_to_arg_spec_params(
cls, type_def: type, is_positional: bool = False
) -> Dict[str, Any]:
origin = get_origin(type_def)
args = get_args(type_def)

# if not origin and not args and type_def in BASIC_TYPES:
if type_def in cls.BASIC_TYPES:
# `str`
return {
"type": type_def
# "type": _parse_basic_type(type_def)
}

if type_def == list:
# `list`
return {"nargs": "*"}

if any(origin is t for t in UNION_TYPES):
# `str | int`

retval = {}
first_subtype = args[0]
if first_subtype in cls.BASIC_TYPES:
retval["type"] = first_subtype

if first_subtype == list:
retval["nargs"] = "*"

if get_origin(first_subtype) == list:
retval["nargs"] = "*"
item_type = cls._extract_item_type_from_list_type(first_subtype)
if item_type:
retval["type"] = item_type

if type(None) in args:
retval["required"] = False
return retval

if origin == list:
# `list[str]`
retval = {}
retval["nargs"] = "*"
if args[0] in cls.BASIC_TYPES:
retval["type"] = args[0]
return retval

return {}

@classmethod
def _extract_item_type_from_list_type(cls, type_def) -> Optional[type]:
args = get_args(type_def)
if args[0] in cls.BASIC_TYPES:
return args[0]
return None
73 changes: 73 additions & 0 deletions tests/test_assembling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
"""
import argparse
from typing import Optional
from unittest.mock import MagicMock, call, patch

import pytest
Expand Down Expand Up @@ -762,3 +763,75 @@ def test_is_positional():
# this spec is invalid but validation is out of scope of the function
# as it only checks if the first argument has the leading dash
assert argh.assembling._is_positional(["-f", "foo"]) is False


def test_typing_hints_only_used_when_arg_deco_not_used():
@argh.arg("foo", type=int)
def func_decorated(foo: Optional[float]):
...

def func_undecorated(bar: Optional[float]):
...

parser = argparse.ArgumentParser()
parser.add_argument = MagicMock()
argh.set_default_command(parser, func_decorated)
assert parser.add_argument.mock_calls == [
call("foo", type=int, help=argh.constants.DEFAULT_ARGUMENT_TEMPLATE),
]

parser = argparse.ArgumentParser()
parser.add_argument = MagicMock()
argh.set_default_command(parser, func_undecorated)
assert parser.add_argument.mock_calls == [
call(
"bar",
nargs="?",
type=float,
help=argh.constants.DEFAULT_ARGUMENT_TEMPLATE,
),
]


def test_typing_hints():
def func(
alpha,
beta: str,
gamma: Optional[int] = None,
*,
delta: float = 1.5,
epsilon: Optional[int] = 42,
) -> str:
return f"alpha={alpha}, beta={beta}, gamma={gamma}, delta={delta}, epsilon={epsilon}"

parser = argparse.ArgumentParser()
parser.add_argument = MagicMock()
argh.set_default_command(
parser, func, name_mapping_policy=NameMappingPolicy.BY_NAME_IF_KWONLY
)
assert parser.add_argument.mock_calls == [
call("alpha", help=argh.constants.DEFAULT_ARGUMENT_TEMPLATE),
call("beta", type=str, help=argh.constants.DEFAULT_ARGUMENT_TEMPLATE),
call(
"gamma",
default=None,
nargs="?",
type=int,
help=argh.constants.DEFAULT_ARGUMENT_TEMPLATE,
),
call(
"-d",
"--delta",
type=float,
default=1.5,
help=argh.constants.DEFAULT_ARGUMENT_TEMPLATE,
),
call(
"-e",
"--epsilon",
type=int,
default=42,
required=False,
help=argh.constants.DEFAULT_ARGUMENT_TEMPLATE,
),
]
Loading

0 comments on commit 08f8e16

Please sign in to comment.