From 7486ab56b9a1776be4edb382265219368ed79593 Mon Sep 17 00:00:00 2001 From: Clif Bratcher Date: Sun, 26 May 2024 11:32:30 -0400 Subject: [PATCH] Fixes #8 - Support for simple sequence types --- simplecli/simplecli.py | 27 +++++++++++++++++++--- tests/test_wrap.py | 52 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/simplecli/simplecli.py b/simplecli/simplecli.py index db42d15..3779aa0 100644 --- a/simplecli/simplecli.py +++ b/simplecli/simplecli.py @@ -51,7 +51,7 @@ class UnsupportedType(TypeError): ValueType = Union[type[DefaultIfBool], type[Empty], bool, float, int, str] ArgDict = dict[str, ValueType] ArgList = list[str] -valid_origins = (Union, UnionType) +valid_origins = (Union, UnionType, list, set) class Param(inspect.Parameter): @@ -215,6 +215,9 @@ def datatypes(self) -> list[type]: return [self.annotation] def validate(self, value: ValueType) -> bool: + # Recurse for list handling + if isinstance(value, list): + return all(self.validate(v) for v in value) passed = False for expected_type in self.datatypes: if expected_type is type(None): @@ -234,7 +237,8 @@ def set_value(self, value: ValueType) -> None: ) # Handle datatypes args = get_args(self.annotation) - if args: + origin = get_origin(self.annotation) + if origin in (Union, UnionType): for type_arg in args: with ( contextlib.suppress(TypeError), @@ -250,6 +254,19 @@ def set_value(self, value: ValueType) -> None: result = value self._value = self.annotation(result) + def set_value_as_seq(self, values: ArgList) -> None: + args = get_args(self.annotation) + origin = get_origin(self.annotation) + self._value = [] + temp_value = [] + for value in values: + if self.validate(value) is False: + raise ValueError( + f"'{self.help_name}' must be of type {self.help_type}" + ) + temp_value.append(args[0](value)) + self._value = origin(temp_value) + def tokenize_string(string: str) -> Generator[TokenInfo, None, None]: return generate_tokens(io.StringIO(string).readline) @@ -329,8 +346,12 @@ def params_to_kwargs( try: for param in params: kw_value = kw_args.get(param.name) + if get_origin(param.annotation) in (list, set): + # Consume ALL pos_args if list or set + param.set_value_as_seq(pos_args) + pos_args.clear() # Positional arguments take precedence - if pos_args: + elif pos_args: param.set_value(pos_args.pop(0)) elif kw_value: param.set_value(kw_value) diff --git a/tests/test_wrap.py b/tests/test_wrap.py index 783801f..f5bab6a 100644 --- a/tests/test_wrap.py +++ b/tests/test_wrap.py @@ -307,3 +307,55 @@ def code1(foo: [str, int]): code1.__globals__["__name__"] = "__main__" simplecli.wrap(code1) code1.__globals__["__name__"] = code1_name + + +def test_wrap_list_of_strings(capfd, monkeypatch): + monkeypatch.setattr(sys, "argv", ["fn", "this", "is", "a", "test"]) + + def code1(foo: list[str]): + print(" | ".join(foo)) + + code1_name = code1.__globals__["__name__"] + code1.__globals__["__name__"] = "__main__" + simplecli.wrap(code1) + assert capfd.readouterr().out == "this | is | a | test\n" + code1.__globals__["__name__"] = code1_name + + +def test_wrap_list_of_ints(capfd, monkeypatch): + monkeypatch.setattr(sys, "argv", ["filename", "8", "6", "7", "53", "09"]) + + def code1(foo: list[int]): + print(sum(foo)) + + code1_name = code1.__globals__["__name__"] + code1.__globals__["__name__"] = "__main__" + simplecli.wrap(code1) + assert capfd.readouterr().out == "83\n" + code1.__globals__["__name__"] = code1_name + + +def test_wrap_set_of_ints_different(capfd, monkeypatch): + monkeypatch.setattr(sys, "argv", ["filename", "8", "6", "7", "53", "09"]) + + def code1(foo: set[int]): + print(sum(foo)) + + code1_name = code1.__globals__["__name__"] + code1.__globals__["__name__"] = "__main__" + simplecli.wrap(code1) + assert capfd.readouterr().out == "83\n" + code1.__globals__["__name__"] = code1_name + + +def test_wrap_set_of_ints_same(capfd, monkeypatch): + monkeypatch.setattr(sys, "argv", ["filename", "8", "8", "8", "1", "08"]) + + def code1(foo: set[int]): + print(sum(foo)) + + code1_name = code1.__globals__["__name__"] + code1.__globals__["__name__"] = "__main__" + simplecli.wrap(code1) + assert capfd.readouterr().out == "9\n" + code1.__globals__["__name__"] = code1_name