Skip to content

Commit

Permalink
Fixes #8 - Support for simple sequence types
Browse files Browse the repository at this point in the history
  • Loading branch information
inno committed May 26, 2024
1 parent a14d0c5 commit 7486ab5
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 3 deletions.
27 changes: 24 additions & 3 deletions simplecli/simplecli.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class UnsupportedType(TypeError):
ValueType = Union[type[DefaultIfBool], type[Empty], bool, float, int, str]
ArgDict = dict[str, ValueType]
ArgList = list[str]
valid_origins = (Union, UnionType)
valid_origins = (Union, UnionType, list, set)


class Param(inspect.Parameter):
Expand Down Expand Up @@ -215,6 +215,9 @@ def datatypes(self) -> list[type]:
return [self.annotation]

def validate(self, value: ValueType) -> bool:
# Recurse for list handling
if isinstance(value, list):
return all(self.validate(v) for v in value)
passed = False
for expected_type in self.datatypes:
if expected_type is type(None):
Expand All @@ -234,7 +237,8 @@ def set_value(self, value: ValueType) -> None:
)
# Handle datatypes
args = get_args(self.annotation)
if args:
origin = get_origin(self.annotation)
if origin in (Union, UnionType):
for type_arg in args:
with (
contextlib.suppress(TypeError),
Expand All @@ -250,6 +254,19 @@ def set_value(self, value: ValueType) -> None:
result = value
self._value = self.annotation(result)

def set_value_as_seq(self, values: ArgList) -> None:
args = get_args(self.annotation)
origin = get_origin(self.annotation)
self._value = []
temp_value = []
for value in values:
if self.validate(value) is False:
raise ValueError(
f"'{self.help_name}' must be of type {self.help_type}"
)
temp_value.append(args[0](value))
self._value = origin(temp_value)


def tokenize_string(string: str) -> Generator[TokenInfo, None, None]:
return generate_tokens(io.StringIO(string).readline)
Expand Down Expand Up @@ -329,8 +346,12 @@ def params_to_kwargs(
try:
for param in params:
kw_value = kw_args.get(param.name)
if get_origin(param.annotation) in (list, set):
# Consume ALL pos_args if list or set
param.set_value_as_seq(pos_args)
pos_args.clear()
# Positional arguments take precedence
if pos_args:
elif pos_args:
param.set_value(pos_args.pop(0))
elif kw_value:
param.set_value(kw_value)
Expand Down
52 changes: 52 additions & 0 deletions tests/test_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,55 @@ def code1(foo: [str, int]):
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
code1.__globals__["__name__"] = code1_name


def test_wrap_list_of_strings(capfd, monkeypatch):
monkeypatch.setattr(sys, "argv", ["fn", "this", "is", "a", "test"])

def code1(foo: list[str]):
print(" | ".join(foo))

code1_name = code1.__globals__["__name__"]
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
assert capfd.readouterr().out == "this | is | a | test\n"
code1.__globals__["__name__"] = code1_name


def test_wrap_list_of_ints(capfd, monkeypatch):
monkeypatch.setattr(sys, "argv", ["filename", "8", "6", "7", "53", "09"])

def code1(foo: list[int]):
print(sum(foo))

code1_name = code1.__globals__["__name__"]
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
assert capfd.readouterr().out == "83\n"
code1.__globals__["__name__"] = code1_name


def test_wrap_set_of_ints_different(capfd, monkeypatch):
monkeypatch.setattr(sys, "argv", ["filename", "8", "6", "7", "53", "09"])

def code1(foo: set[int]):
print(sum(foo))

code1_name = code1.__globals__["__name__"]
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
assert capfd.readouterr().out == "83\n"
code1.__globals__["__name__"] = code1_name


def test_wrap_set_of_ints_same(capfd, monkeypatch):
monkeypatch.setattr(sys, "argv", ["filename", "8", "8", "8", "1", "08"])

def code1(foo: set[int]):
print(sum(foo))

code1_name = code1.__globals__["__name__"]
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
assert capfd.readouterr().out == "9\n"
code1.__globals__["__name__"] = code1_name

0 comments on commit 7486ab5

Please sign in to comment.