diff --git a/simplecli/simplecli.py b/simplecli/simplecli.py index 56d9dc6..3779aa0 100644 --- a/simplecli/simplecli.py +++ b/simplecli/simplecli.py @@ -22,6 +22,7 @@ get_args, get_origin, ) +from types import GenericAlias try: from types import UnionType @@ -50,6 +51,7 @@ class UnsupportedType(TypeError): ValueType = Union[type[DefaultIfBool], type[Empty], bool, float, int, str] ArgDict = dict[str, ValueType] ArgList = list[str] +valid_origins = (Union, UnionType, list, set) class Param(inspect.Parameter): @@ -86,16 +88,25 @@ def __init__(self, *argv: Any, **kwargs: Any) -> None: # Overrides required as these values are generally unused if not self.description: self.parse_or_prepend(param_line) - annotation = kwargs["annotation"] - if annotation not in get_args(ValueType): - if get_origin(annotation) not in (Union, UnionType): - if annotation is not Empty: - pretty_annotation = ( - annotation - if type(annotation) is type - else annotation.__class__.__name__ - ) - raise UnsupportedType(kwargs["name"], pretty_annotation) + self.validate_annotation(kwargs["name"], kwargs["annotation"]) + + def validate_annotation(self, name: str, annotation: object) -> None: + if annotation in get_args(ValueType): + return + if get_origin(annotation) in valid_origins: + return + if annotation is Empty: + return + + pretty_annotation = ( + annotation + if ( + type(annotation) is type + or isinstance(annotation, GenericAlias) + ) + else annotation.__class__.__name__ + ) + raise UnsupportedType(name, pretty_annotation) def __eq__(self, other: object) -> bool: if not isinstance(other, Param): @@ -154,7 +165,7 @@ def help_name(self) -> str: @property def help_type(self) -> str: - if get_origin(self.annotation) in (Union, UnionType): + if get_origin(self.annotation) in valid_origins: typelist = ", ".join([a.__name__ for a in self.datatypes]) return f"[{typelist}]" return self.annotation.__name__ @@ -204,6 +215,9 @@ def datatypes(self) -> list[type]: return [self.annotation] def validate(self, value: ValueType) -> bool: + # Recurse for list handling + if isinstance(value, list): + return all(self.validate(v) for v in value) passed = False for expected_type in self.datatypes: if expected_type is type(None): @@ -223,7 +237,8 @@ def set_value(self, value: ValueType) -> None: ) # Handle datatypes args = get_args(self.annotation) - if args: + origin = get_origin(self.annotation) + if origin in (Union, UnionType): for type_arg in args: with ( contextlib.suppress(TypeError), @@ -239,6 +254,19 @@ def set_value(self, value: ValueType) -> None: result = value self._value = self.annotation(result) + def set_value_as_seq(self, values: ArgList) -> None: + args = get_args(self.annotation) + origin = get_origin(self.annotation) + self._value = [] + temp_value = [] + for value in values: + if self.validate(value) is False: + raise ValueError( + f"'{self.help_name}' must be of type {self.help_type}" + ) + temp_value.append(args[0](value)) + self._value = origin(temp_value) + def tokenize_string(string: str) -> Generator[TokenInfo, None, None]: return generate_tokens(io.StringIO(string).readline) @@ -317,21 +345,19 @@ def params_to_kwargs( missing_params = [] try: for param in params: + kw_value = kw_args.get(param.name) + if get_origin(param.annotation) in (list, set): + # Consume ALL pos_args if list or set + param.set_value_as_seq(pos_args) + pos_args.clear() # Positional arguments take precedence - if pos_args: + elif pos_args: param.set_value(pos_args.pop(0)) - elif param.name in kw_args: - if kw_args[param.name] is DefaultIfBool: - # Invert the default value - param.set_value( - True if param.default is Empty else not param.default - ) - continue - param.set_value(kw_args[param.name]) + elif kw_value: + param.set_value(kw_value) continue elif param.required: missing_params.append(param) - continue except ValueError as e: exit(e.args[0]) diff --git a/tests/test_param.py b/tests/test_param.py index f8d64fd..ce5322f 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -190,3 +190,25 @@ def test_param_boolean_implied_false(): assert p1.required is False assert p1.optional is False assert p1.value is False + + +def test_param_validate_list(): + p1 = Param(name="testparam1", annotation=list[str]) + assert p1.validate(["foo", "bar"]) is True + assert p1.required is True + assert p1.optional is False + + +def test_param_set_value_as_seq_valid(): + p1 = Param(name="testparam1", annotation=list[str]) + p1.set_value_as_seq([123, "bar"]) + assert p1.required is True + assert p1.optional is False + + +def test_param_set_value_as_seq_invalid(): + p1 = Param(name="testparam1", annotation=list[int]) + with pytest.raises(ValueError, match="int"): + p1.set_value_as_seq([123, "bar"]) + assert p1.required is True + assert p1.optional is False diff --git a/tests/test_wrap.py b/tests/test_wrap.py index 78c0620..f5bab6a 100644 --- a/tests/test_wrap.py +++ b/tests/test_wrap.py @@ -27,11 +27,15 @@ def code1(a: int): def code2(a: int): pass + code1_name = code1.__globals__["__name__"] + code2_name = code2.__globals__["__name__"] with pytest.raises(SystemExit, match="only ONE"): code1.__globals__["__name__"] = "__main__" code2.__globals__["__name__"] = "__main__" simplecli.wrap(code1) simplecli.wrap(code2) + code1.__globals__["__name__"] = code1_name + code2.__globals__["__name__"] = code2_name def test_wrap_help_simple(monkeypatch): @@ -40,9 +44,11 @@ def test_wrap_help_simple(monkeypatch): def code1(this_var: int): # stuff and things pass + code1_name = code1.__globals__["__name__"] with pytest.raises(SystemExit) as e: code1.__globals__["__name__"] = "__main__" simplecli.wrap(code1) + code1.__globals__["__name__"] = code1_name help_msg = e.value.args[0] assert "--this-var" in help_msg @@ -59,9 +65,11 @@ def code( ): pass + code_name = code.__globals__["__name__"] with pytest.raises(SystemExit) as e: code.__globals__["__name__"] = "__main__" simplecli.wrap(code) + code.__globals__["__name__"] = code_name help_msg = e.value.args[0] assert "--that-var" in help_msg @@ -76,9 +84,11 @@ def test_wrap_simple_type_error(monkeypatch): def code(a: int): pass + code_name = code.__globals__["__name__"] with pytest.raises(SystemExit, match="Too many positional"): code.__globals__["__name__"] = "__main__" simplecli.wrap(code) + code.__globals__["__name__"] = code_name def test_wrap_simple_value_error(monkeypatch): @@ -87,9 +97,11 @@ def test_wrap_simple_value_error(monkeypatch): def code(a: int): pass + code_name = code.__globals__["__name__"] with pytest.raises(SystemExit, match="must be of type int"): code.__globals__["__name__"] = "__main__" simplecli.wrap(code) + code.__globals__["__name__"] = code_name def test_wrap_version_absent(monkeypatch): @@ -98,9 +110,11 @@ def test_wrap_version_absent(monkeypatch): def code2(this_var: int): # stuff and things pass + code2_name = code2.__globals__["__name__"] with pytest.raises(SystemExit) as e: code2.__globals__["__name__"] = "__main__" simplecli.wrap(code2) + code2.__globals__["__name__"] = code2_name help_msg = e.value.args[0] assert "Description:" not in help_msg @@ -115,10 +129,12 @@ def test_wrap_version_exists(monkeypatch): def code1(this_var: int): # stuff and things pass + code1_name = code1.__globals__["__name__"] with pytest.raises(SystemExit) as e: code1.__globals__["__name__"] = "__main__" code1.__globals__["__version__"] = "1.2.3" simplecli.wrap(code1) + code1.__globals__["__name__"] = code1_name help_msg = e.value.args[0] assert "super_script version 1.2.3" in help_msg @@ -135,9 +151,11 @@ def code2(this_var: int): # stuff and things this is a description """ + code2_name = code2.__globals__["__name__"] with pytest.raises(SystemExit) as e: code2.__globals__["__name__"] = "__main__" simplecli.wrap(code2) + code2.__globals__["__name__"] = code2_name help_msg = e.value.args[0] assert "Description:" in help_msg @@ -156,9 +174,11 @@ def code( ): pass + code_name = code.__globals__["__name__"] with pytest.raises(SystemExit) as e: code.__globals__["__name__"] = "__main__" simplecli.wrap(code) + code.__globals__["__name__"] = code_name help_msg = e.value.args[0] assert "--that-var" in help_msg @@ -174,10 +194,12 @@ def test_wrap_simple_positional(capfd, monkeypatch): def code(a: int, b: int, c: typing.Union[int, float]): print(a + b + c) + code_name = code.__globals__["__name__"] code.__globals__["__name__"] = "__main__" simplecli.wrap(code) (out, _) = capfd.readouterr() assert out.strip() == "6.5" + code.__globals__["__name__"] = code_name def test_wrap_boolean_false(monkeypatch): @@ -241,9 +263,11 @@ def test_wrap_no_typehint_(monkeypatch): def code1(foo): pass + code1_name = code1.__globals__["__name__"] with pytest.raises(SystemExit, match="parameters need type hints!"): code1.__globals__["__name__"] = "__main__" simplecli.wrap(code1) + code1.__globals__["__name__"] = code1_name def test_wrap_no_typehint_no_arg(monkeypatch): @@ -252,9 +276,11 @@ def test_wrap_no_typehint_no_arg(monkeypatch): def code1(foo): pass + code1_name = code1.__globals__["__name__"] with pytest.raises(SystemExit, match="ERROR: All wrapped function "): code1.__globals__["__name__"] = "__main__" simplecli.wrap(code1) + code1.__globals__["__name__"] = code1_name def test_wrap_no_typehint_kwarg(monkeypatch): @@ -263,9 +289,11 @@ def test_wrap_no_typehint_kwarg(monkeypatch): def code1(foo): pass + code1_name = code1.__globals__["__name__"] with pytest.raises(SystemExit, match="function parameters need type"): code1.__globals__["__name__"] = "__main__" simplecli.wrap(code1) + code1.__globals__["__name__"] = code1_name def test_wrap_unsupported_type(monkeypatch): @@ -274,6 +302,60 @@ def test_wrap_unsupported_type(monkeypatch): def code1(foo: [str, int]): pass + code1_name = code1.__globals__["__name__"] with pytest.raises(SystemExit, match="UnsupportedType: list"): code1.__globals__["__name__"] = "__main__" simplecli.wrap(code1) + code1.__globals__["__name__"] = code1_name + + +def test_wrap_list_of_strings(capfd, monkeypatch): + monkeypatch.setattr(sys, "argv", ["fn", "this", "is", "a", "test"]) + + def code1(foo: list[str]): + print(" | ".join(foo)) + + code1_name = code1.__globals__["__name__"] + code1.__globals__["__name__"] = "__main__" + simplecli.wrap(code1) + assert capfd.readouterr().out == "this | is | a | test\n" + code1.__globals__["__name__"] = code1_name + + +def test_wrap_list_of_ints(capfd, monkeypatch): + monkeypatch.setattr(sys, "argv", ["filename", "8", "6", "7", "53", "09"]) + + def code1(foo: list[int]): + print(sum(foo)) + + code1_name = code1.__globals__["__name__"] + code1.__globals__["__name__"] = "__main__" + simplecli.wrap(code1) + assert capfd.readouterr().out == "83\n" + code1.__globals__["__name__"] = code1_name + + +def test_wrap_set_of_ints_different(capfd, monkeypatch): + monkeypatch.setattr(sys, "argv", ["filename", "8", "6", "7", "53", "09"]) + + def code1(foo: set[int]): + print(sum(foo)) + + code1_name = code1.__globals__["__name__"] + code1.__globals__["__name__"] = "__main__" + simplecli.wrap(code1) + assert capfd.readouterr().out == "83\n" + code1.__globals__["__name__"] = code1_name + + +def test_wrap_set_of_ints_same(capfd, monkeypatch): + monkeypatch.setattr(sys, "argv", ["filename", "8", "8", "8", "1", "08"]) + + def code1(foo: set[int]): + print(sum(foo)) + + code1_name = code1.__globals__["__name__"] + code1.__globals__["__name__"] = "__main__" + simplecli.wrap(code1) + assert capfd.readouterr().out == "9\n" + code1.__globals__["__name__"] = code1_name