Skip to content

Commit

Permalink
Add decorator to parse function type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
keanpantraw committed Aug 19, 2021
1 parent ed44d8b commit 67d4940
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 0 deletions.
48 changes: 48 additions & 0 deletions docs/guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,54 @@ flag (as in `--obj=True`), or by making sure there's another flag after any
boolean flag argument.


#### Type hints

Fire can be configured to use type hints information by decorating functions with `UseTypeHints()` decorator.
Only `int`, `float` and `str` type hints are respected by default, everything else is ignored (parsed as usual).
Quite common usecase is to instruct fire not to convert strings to integer/floats by supplying `str`
type annotation.

See minimal example below:

```python
import fire

from fire.decorators import UseTypeHints


@UseTypeHints() # () are mandatory here
def main(a: str, b: float):
print(type(a), type(b), type(c), type(d))


if __name__ == "__main__":
fire.Fire(main)
```

When invoked with `python command.py 1 2` this code will produce `str float`.

You can pass custom type hints parsers via decorator argument, following example shows how to parse custom lists:

```python
import fire

from fire.decorators import UseTypeHints


@UseTypeHints({list: lambda arg: [float(x) for x in arg.split(";")]})
def main(a: list, b: str):
print(a)


if __name__ == "__main__":
fire.Fire(main)
```

This code will convert argument `1;2;3;4` argument into `[1.0, 2.0, 3.0, 4.0]` list with floats.
To override default behavior for `int`, `str`, and `float` type hints you need to add them into dictionary supplied to
`UseTypeHints` decorator.


### Using Fire Flags

Fire CLIs all come with a number of flags. These flags should be separated from
Expand Down
40 changes: 40 additions & 0 deletions fire/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,46 @@
ACCEPTS_POSITIONAL_ARGS = 'ACCEPTS_POSITIONAL_ARGS'


def UseTypeHints(type_hints_mapping=None):
"""Instruct fire to use type hints information when parsing args for this
function.
Args:
type_hints_mapping: mapping of type hints into parsing functions, by
default floats, ints and strings are treated, and all other type
hints are ignored (parsed as usual)
Returns:
The decorated function, which now has metadata telling Fire how to perform
according to type hints.
Examples:
@UseTypeHints()
def main(a, b:int, c:float=2.0)
assert isinstance(b, int)
assert isinstance(c, float)
@UseTypeHints({list: lambda s: s.split(";")})
def main(a, c: list):
assert isinstance(c, list)
"""
default_type_hints_mapping = {float: float, int: int, str: str}
if type_hints_mapping is None:
type_hints_mapping = {}
type_hints_mapping.update(default_type_hints_mapping)

def _Decorator(fn):
signature = inspect.signature(fn)
named = {}
for name, param in signature.parameters.items():
has_type_hint = param.annotation is not param.empty
if has_type_hint and param.annotation in type_hints_mapping:
named[name] = type_hints_mapping[param.annotation]
decorator = SetParseFns(**named)
decorated_func = decorator(fn)
return decorated_func
return _Decorator


def SetParseFn(fn, *arguments):
"""Sets the fn for Fire to use to parse args when calling the decorated fn.
Expand Down
32 changes: 32 additions & 0 deletions fire/decorators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import unittest

from fire import core
from fire import decorators
Expand Down Expand Up @@ -90,6 +92,18 @@ def example7(self, arg1, arg2=None, *varargs, **kwargs): # pylint: disable=keyw
return arg1, arg2, varargs, kwargs


if sys.version_info >= (3, 5):
class WithTypeHints(object):

@decorators.UseTypeHints()
def example8(self, a: int, b: str, c, d : float = None):
return a, b, c, d

@decorators.UseTypeHints({list: lambda arg: list(map(int, arg.split(";")))})
def example9(self, a: str, b, c: list, d : list = None):
return a, b, c, d


class FireDecoratorsTest(testutils.BaseTestCase):

def testSetParseFnsNamedArgs(self):
Expand Down Expand Up @@ -169,6 +183,24 @@ def testSetParseFn(self):
command=['example7', '1', '--arg2=2', '3', '4', '--kwarg=5']),
('1', '2', ('3', '4'), {'kwarg': '5'}))

@unittest.skipIf(sys.version_info < (3, 5),
'Type hints were introduced in python 3.5')
def testDefaultTypeHints(self):
self.assertEqual(
core.Fire(WithTypeHints,
command=['example8', '1', '2', '3', '--d=4']),
(1, '2', 3, 4)
)

@unittest.skipIf(sys.version_info < (3, 5),
'Type hints were introduced in python 3.5')
def testCustomTypeHints(self):
self.assertEqual(
core.Fire(WithTypeHints,
command=['example9', '1', '2', '3', '--d=4;5;6']),
('1', 2, [3], [4, 5, 6])
)


if __name__ == '__main__':
testutils.main()

0 comments on commit 67d4940

Please sign in to comment.