Skip to content

Commit

Permalink
Support for simple sequence types (#10)
Browse files Browse the repository at this point in the history
* Fixes #8 - Support for simple sequence types
* Cleaner error messages with unsupported types
* Shift valid origins to a variable
* Avoid poluting globals in tests
* Redundant DefaultIfBool handling can now be removed
* Test globals affecting triggering the need for this have been resolved
  • Loading branch information
inno authored May 26, 2024
1 parent f90e10c commit 7475df8
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 22 deletions.
70 changes: 48 additions & 22 deletions simplecli/simplecli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
get_args,
get_origin,
)
from types import GenericAlias

try:
from types import UnionType
Expand Down Expand Up @@ -50,6 +51,7 @@ class UnsupportedType(TypeError):
ValueType = Union[type[DefaultIfBool], type[Empty], bool, float, int, str]
ArgDict = dict[str, ValueType]
ArgList = list[str]
valid_origins = (Union, UnionType, list, set)


class Param(inspect.Parameter):
Expand Down Expand Up @@ -86,16 +88,25 @@ def __init__(self, *argv: Any, **kwargs: Any) -> None:
# Overrides required as these values are generally unused
if not self.description:
self.parse_or_prepend(param_line)
annotation = kwargs["annotation"]
if annotation not in get_args(ValueType):
if get_origin(annotation) not in (Union, UnionType):
if annotation is not Empty:
pretty_annotation = (
annotation
if type(annotation) is type
else annotation.__class__.__name__
)
raise UnsupportedType(kwargs["name"], pretty_annotation)
self.validate_annotation(kwargs["name"], kwargs["annotation"])

def validate_annotation(self, name: str, annotation: object) -> None:
if annotation in get_args(ValueType):
return
if get_origin(annotation) in valid_origins:
return
if annotation is Empty:
return

pretty_annotation = (
annotation
if (
type(annotation) is type
or isinstance(annotation, GenericAlias)
)
else annotation.__class__.__name__
)
raise UnsupportedType(name, pretty_annotation)

def __eq__(self, other: object) -> bool:
if not isinstance(other, Param):
Expand Down Expand Up @@ -154,7 +165,7 @@ def help_name(self) -> str:

@property
def help_type(self) -> str:
if get_origin(self.annotation) in (Union, UnionType):
if get_origin(self.annotation) in valid_origins:
typelist = ", ".join([a.__name__ for a in self.datatypes])
return f"[{typelist}]"
return self.annotation.__name__
Expand Down Expand Up @@ -204,6 +215,9 @@ def datatypes(self) -> list[type]:
return [self.annotation]

def validate(self, value: ValueType) -> bool:
# Recurse for list handling
if isinstance(value, list):
return all(self.validate(v) for v in value)
passed = False
for expected_type in self.datatypes:
if expected_type is type(None):
Expand All @@ -223,7 +237,8 @@ def set_value(self, value: ValueType) -> None:
)
# Handle datatypes
args = get_args(self.annotation)
if args:
origin = get_origin(self.annotation)
if origin in (Union, UnionType):
for type_arg in args:
with (
contextlib.suppress(TypeError),
Expand All @@ -239,6 +254,19 @@ def set_value(self, value: ValueType) -> None:
result = value
self._value = self.annotation(result)

def set_value_as_seq(self, values: ArgList) -> None:
args = get_args(self.annotation)
origin = get_origin(self.annotation)
self._value = []
temp_value = []
for value in values:
if self.validate(value) is False:
raise ValueError(
f"'{self.help_name}' must be of type {self.help_type}"
)
temp_value.append(args[0](value))
self._value = origin(temp_value)


def tokenize_string(string: str) -> Generator[TokenInfo, None, None]:
return generate_tokens(io.StringIO(string).readline)
Expand Down Expand Up @@ -317,21 +345,19 @@ def params_to_kwargs(
missing_params = []
try:
for param in params:
kw_value = kw_args.get(param.name)
if get_origin(param.annotation) in (list, set):
# Consume ALL pos_args if list or set
param.set_value_as_seq(pos_args)
pos_args.clear()
# Positional arguments take precedence
if pos_args:
elif pos_args:
param.set_value(pos_args.pop(0))
elif param.name in kw_args:
if kw_args[param.name] is DefaultIfBool:
# Invert the default value
param.set_value(
True if param.default is Empty else not param.default
)
continue
param.set_value(kw_args[param.name])
elif kw_value:
param.set_value(kw_value)
continue
elif param.required:
missing_params.append(param)
continue
except ValueError as e:
exit(e.args[0])

Expand Down
22 changes: 22 additions & 0 deletions tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,25 @@ def test_param_boolean_implied_false():
assert p1.required is False
assert p1.optional is False
assert p1.value is False


def test_param_validate_list():
p1 = Param(name="testparam1", annotation=list[str])
assert p1.validate(["foo", "bar"]) is True
assert p1.required is True
assert p1.optional is False


def test_param_set_value_as_seq_valid():
p1 = Param(name="testparam1", annotation=list[str])
p1.set_value_as_seq([123, "bar"])
assert p1.required is True
assert p1.optional is False


def test_param_set_value_as_seq_invalid():
p1 = Param(name="testparam1", annotation=list[int])
with pytest.raises(ValueError, match="int"):
p1.set_value_as_seq([123, "bar"])
assert p1.required is True
assert p1.optional is False
82 changes: 82 additions & 0 deletions tests/test_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ def code1(a: int):
def code2(a: int):
pass

code1_name = code1.__globals__["__name__"]
code2_name = code2.__globals__["__name__"]
with pytest.raises(SystemExit, match="only ONE"):
code1.__globals__["__name__"] = "__main__"
code2.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
simplecli.wrap(code2)
code1.__globals__["__name__"] = code1_name
code2.__globals__["__name__"] = code2_name


def test_wrap_help_simple(monkeypatch):
Expand All @@ -40,9 +44,11 @@ def test_wrap_help_simple(monkeypatch):
def code1(this_var: int): # stuff and things
pass

code1_name = code1.__globals__["__name__"]
with pytest.raises(SystemExit) as e:
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
code1.__globals__["__name__"] = code1_name

help_msg = e.value.args[0]
assert "--this-var" in help_msg
Expand All @@ -59,9 +65,11 @@ def code(
):
pass

code_name = code.__globals__["__name__"]
with pytest.raises(SystemExit) as e:
code.__globals__["__name__"] = "__main__"
simplecli.wrap(code)
code.__globals__["__name__"] = code_name

help_msg = e.value.args[0]
assert "--that-var" in help_msg
Expand All @@ -76,9 +84,11 @@ def test_wrap_simple_type_error(monkeypatch):
def code(a: int):
pass

code_name = code.__globals__["__name__"]
with pytest.raises(SystemExit, match="Too many positional"):
code.__globals__["__name__"] = "__main__"
simplecli.wrap(code)
code.__globals__["__name__"] = code_name


def test_wrap_simple_value_error(monkeypatch):
Expand All @@ -87,9 +97,11 @@ def test_wrap_simple_value_error(monkeypatch):
def code(a: int):
pass

code_name = code.__globals__["__name__"]
with pytest.raises(SystemExit, match="must be of type int"):
code.__globals__["__name__"] = "__main__"
simplecli.wrap(code)
code.__globals__["__name__"] = code_name


def test_wrap_version_absent(monkeypatch):
Expand All @@ -98,9 +110,11 @@ def test_wrap_version_absent(monkeypatch):
def code2(this_var: int): # stuff and things
pass

code2_name = code2.__globals__["__name__"]
with pytest.raises(SystemExit) as e:
code2.__globals__["__name__"] = "__main__"
simplecli.wrap(code2)
code2.__globals__["__name__"] = code2_name

help_msg = e.value.args[0]
assert "Description:" not in help_msg
Expand All @@ -115,10 +129,12 @@ def test_wrap_version_exists(monkeypatch):
def code1(this_var: int): # stuff and things
pass

code1_name = code1.__globals__["__name__"]
with pytest.raises(SystemExit) as e:
code1.__globals__["__name__"] = "__main__"
code1.__globals__["__version__"] = "1.2.3"
simplecli.wrap(code1)
code1.__globals__["__name__"] = code1_name

help_msg = e.value.args[0]
assert "super_script version 1.2.3" in help_msg
Expand All @@ -135,9 +151,11 @@ def code2(this_var: int): # stuff and things
this is a description
"""

code2_name = code2.__globals__["__name__"]
with pytest.raises(SystemExit) as e:
code2.__globals__["__name__"] = "__main__"
simplecli.wrap(code2)
code2.__globals__["__name__"] = code2_name

help_msg = e.value.args[0]
assert "Description:" in help_msg
Expand All @@ -156,9 +174,11 @@ def code(
):
pass

code_name = code.__globals__["__name__"]
with pytest.raises(SystemExit) as e:
code.__globals__["__name__"] = "__main__"
simplecli.wrap(code)
code.__globals__["__name__"] = code_name

help_msg = e.value.args[0]
assert "--that-var" in help_msg
Expand All @@ -174,10 +194,12 @@ def test_wrap_simple_positional(capfd, monkeypatch):
def code(a: int, b: int, c: typing.Union[int, float]):
print(a + b + c)

code_name = code.__globals__["__name__"]
code.__globals__["__name__"] = "__main__"
simplecli.wrap(code)
(out, _) = capfd.readouterr()
assert out.strip() == "6.5"
code.__globals__["__name__"] = code_name


def test_wrap_boolean_false(monkeypatch):
Expand Down Expand Up @@ -241,9 +263,11 @@ def test_wrap_no_typehint_(monkeypatch):
def code1(foo):
pass

code1_name = code1.__globals__["__name__"]
with pytest.raises(SystemExit, match="parameters need type hints!"):
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
code1.__globals__["__name__"] = code1_name


def test_wrap_no_typehint_no_arg(monkeypatch):
Expand All @@ -252,9 +276,11 @@ def test_wrap_no_typehint_no_arg(monkeypatch):
def code1(foo):
pass

code1_name = code1.__globals__["__name__"]
with pytest.raises(SystemExit, match="ERROR: All wrapped function "):
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
code1.__globals__["__name__"] = code1_name


def test_wrap_no_typehint_kwarg(monkeypatch):
Expand All @@ -263,9 +289,11 @@ def test_wrap_no_typehint_kwarg(monkeypatch):
def code1(foo):
pass

code1_name = code1.__globals__["__name__"]
with pytest.raises(SystemExit, match="function parameters need type"):
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
code1.__globals__["__name__"] = code1_name


def test_wrap_unsupported_type(monkeypatch):
Expand All @@ -274,6 +302,60 @@ def test_wrap_unsupported_type(monkeypatch):
def code1(foo: [str, int]):
pass

code1_name = code1.__globals__["__name__"]
with pytest.raises(SystemExit, match="UnsupportedType: list"):
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
code1.__globals__["__name__"] = code1_name


def test_wrap_list_of_strings(capfd, monkeypatch):
monkeypatch.setattr(sys, "argv", ["fn", "this", "is", "a", "test"])

def code1(foo: list[str]):
print(" | ".join(foo))

code1_name = code1.__globals__["__name__"]
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
assert capfd.readouterr().out == "this | is | a | test\n"
code1.__globals__["__name__"] = code1_name


def test_wrap_list_of_ints(capfd, monkeypatch):
monkeypatch.setattr(sys, "argv", ["filename", "8", "6", "7", "53", "09"])

def code1(foo: list[int]):
print(sum(foo))

code1_name = code1.__globals__["__name__"]
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
assert capfd.readouterr().out == "83\n"
code1.__globals__["__name__"] = code1_name


def test_wrap_set_of_ints_different(capfd, monkeypatch):
monkeypatch.setattr(sys, "argv", ["filename", "8", "6", "7", "53", "09"])

def code1(foo: set[int]):
print(sum(foo))

code1_name = code1.__globals__["__name__"]
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
assert capfd.readouterr().out == "83\n"
code1.__globals__["__name__"] = code1_name


def test_wrap_set_of_ints_same(capfd, monkeypatch):
monkeypatch.setattr(sys, "argv", ["filename", "8", "8", "8", "1", "08"])

def code1(foo: set[int]):
print(sum(foo))

code1_name = code1.__globals__["__name__"]
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
assert capfd.readouterr().out == "9\n"
code1.__globals__["__name__"] = code1_name

0 comments on commit 7475df8

Please sign in to comment.