Skip to content

Commit

Permalink
PatchGlobal was hiding bugs, remove it
Browse files Browse the repository at this point in the history
Bugs discovered/fixed:
  * Positional arguments completely broken
  * Union types not consistently iterating before erroring out
  * extract_code_params is too complex
  * Validation attempting to apply a None type
  • Loading branch information
inno committed May 5, 2024
1 parent f2796f1 commit ebbe16e
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 75 deletions.
40 changes: 28 additions & 12 deletions simplecli/simplecli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import contextlib
import inspect
import io
Expand Down Expand Up @@ -191,6 +192,8 @@ def datatypes(self) -> list[type]:
def validate(self, value: ValueType) -> bool:
passed = False
for expected_type in self.datatypes:
if expected_type is type(None):
continue
try:
expected_type(value)
return True
Expand All @@ -207,7 +210,10 @@ def set_value(self, value: ValueType) -> None:
args = get_args(self.annotation)
if args:
for type_arg in args:
with contextlib.suppress(TypeError):
with (
contextlib.suppress(TypeError),
contextlib.suppress(ValueError),
):
self._value = type_arg(value)
return
if value is DefaultIfBool:
Expand Down Expand Up @@ -318,6 +324,7 @@ def params_to_kwargs(
if missing_params:
raise TypeError(*missing_params_msg(missing_params))

# If any value from kw_args is not in params, exit with prejiduce!
check_for_unexpected_args(params, kw_args)
return {param.name: param.value for param in params}

Expand Down Expand Up @@ -386,30 +393,40 @@ def code_to_ordered_params(code: Callable[..., Any]) -> OrderedDict:
)


def process_comment(
param: Param | None,
params: list[Param],
token: TokenInfo,
) -> str:
comment = token.string
if params and param is None:
params[-1].parse_or_prepend(token.line, comment)
comment = ""
elif param:
param.parse_or_prepend(token.line, comment, False)
return comment


def extract_code_params(code: Callable[..., Any]) -> list[Param]:
tokens = tokenize_string(inspect.getsource(code))
ordered_params = code_to_ordered_params(code)
hints = {k: v.annotation for k, v in ordered_params.items()}.copy()
comment = ""
param = None
params: list[Param] = []

depth = 0
for token in tokens:
depth += 1
for token in tokenize_string(inspect.getsource(code)):
if token.exact_type is COMMENT:
comment = token.string
if params and param is None:
params[-1].parse_or_prepend(token.line, comment)
comment = ""
elif param:
param.parse_or_prepend(token.line, comment, False)
comment = process_comment(param, params, token)
continue
# tokenize.NL -
# when a logical line of code is continued over multiple lines
if token.exact_type is NL and param:
param.parse_or_prepend(token.line, comment)
elif token.exact_type is NAME and token.string in hints:
if param is not None:
comment = ""
params.append(param)
param = None
hints.pop(token.string)
param = ordered_params.pop(token.string)
if not param.parse_or_prepend(token.line, comment):
Expand All @@ -420,7 +437,6 @@ def extract_code_params(code: Callable[..., Any]) -> list[Param]:
comment = ""
params.append(param)
param = None
# Necessary for < py3.12
if param:
params.append(param)
return params
2 changes: 0 additions & 2 deletions tests/test_help_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ def test_help_text_union():
filename="filename",
params=[Param(name="somevar", annotation=Union[int, float])],
)
print(text)
assert "[int, float]" in text
assert "OPTIONAL" not in text

Expand All @@ -19,7 +18,6 @@ def test_help_text_uniontype():
filename="filename",
params=[Param(name="somevar", annotation=float | int)],
)
print(text)
assert "[float, int]" in text
assert "OPTIONAL" not in text

Expand Down
8 changes: 8 additions & 0 deletions tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def test_parse_or_prepend_optional():

def test_parse_or_prepend_union_none():
p1 = Param(name="testparam1", annotation=Union[None, str])
assert p1.internal_only is False
assert p1.optional is True
assert p1.required is False
assert p1.description == ""
Expand All @@ -160,3 +161,10 @@ def test_parse_or_prepend_union_none():
assert p1.optional is True
assert p1.required is False
assert p1.description == "blarg"


def test_internal_only():
p1 = Param(name="testparam1", internal_only=True)
assert p1.internal_only is True
assert p1.required is False
assert p1.optional is False
90 changes: 29 additions & 61 deletions tests/test_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,6 @@ def ensure_wrapped_not_flagged():
simplecli._wrapped = False


class PatchGlobal:
def __init__(
self,
func: typing.Callable,
name: str,
value: str,
) -> None:
self.func = func
self.name = name
self.value = value
self._no_actual_value = False

def __enter__(self) -> "PatchGlobal":
if self.name not in self.func.__globals__:
self._no_actual_value = True
else:
self.actual_value = self.func.__globals__.get(self.name)
self.func.__globals__[self.name] = self.value
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> "PatchGlobal":
if self._no_actual_value:
self.func.__globals__.pop(self.name)
else:
self.func.__globals__[self.name] = self.actual_value
return self


def test_wrap_simple(monkeypatch):
monkeypatch.setattr(sys, "argv", ["filename", "123"])

Expand Down Expand Up @@ -68,10 +40,8 @@ def test_wrap_help_simple(monkeypatch):
def code1(this_var: int): # stuff and things
pass

with (
PatchGlobal(code1, "__name__", "__main__"),
pytest.raises(SystemExit) as e,
):
with pytest.raises(SystemExit) as e:
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)

help_msg = e.value.args[0]
Expand All @@ -90,10 +60,8 @@ def code(
):
pass

with (
PatchGlobal(code, "__name__", "__main__"),
pytest.raises(SystemExit) as e,
):
with pytest.raises(SystemExit) as e:
code.__globals__["__name__"] = "__main__"
simplecli.wrap(code)

help_msg = e.value.args[0]
Expand All @@ -111,10 +79,8 @@ def test_wrap_simple_type_error(monkeypatch):
def code(a: int):
pass

with (
PatchGlobal(code, "__name__", "__main__"),
pytest.raises(SystemExit, match="Too many positional"),
):
with pytest.raises(SystemExit, match="Too many positional"):
code.__globals__["__name__"] = "__main__"
simplecli.wrap(code)


Expand All @@ -124,10 +90,8 @@ def test_wrap_simple_value_error(monkeypatch):
def code(a: int):
pass

with (
PatchGlobal(code, "__name__", "__main__"),
pytest.raises(SystemExit, match="must be of type int"),
):
with pytest.raises(SystemExit, match="must be of type int"):
code.__globals__["__name__"] = "__main__"
simplecli.wrap(code)


Expand All @@ -137,10 +101,8 @@ def test_wrap_version_absent(monkeypatch):
def code2(this_var: int): # stuff and things
pass

with (
PatchGlobal(code2, "__name__", "__main__"),
pytest.raises(SystemExit) as e,
):
with pytest.raises(SystemExit) as e:
code2.__globals__["__name__"] = "__main__"
simplecli.wrap(code2)

help_msg = e.value.args[0]
Expand All @@ -157,11 +119,9 @@ def test_wrap_version_exists(monkeypatch):
def code1(this_var: int): # stuff and things
pass

with (
PatchGlobal(code1, "__version__", "1.2.3"),
PatchGlobal(code1, "__name__", "__main__"),
pytest.raises(SystemExit) as e,
):
with pytest.raises(SystemExit) as e:
code1.__globals__["__name__"] = "__main__"
code1.__globals__["__version__"] = "1.2.3"
simplecli.wrap(code1)

help_msg = e.value.args[0]
Expand All @@ -180,10 +140,8 @@ def code2(this_var: int): # stuff and things
this is a description
"""

with (
PatchGlobal(code2, "__name__", "__main__"),
pytest.raises(SystemExit) as e,
):
with pytest.raises(SystemExit) as e:
code2.__globals__["__name__"] = "__main__"
simplecli.wrap(code2)

help_msg = e.value.args[0]
Expand All @@ -204,10 +162,8 @@ def code(
):
pass

with (
PatchGlobal(code, "__name__", "__main__"),
pytest.raises(SystemExit) as e,
):
with pytest.raises(SystemExit) as e:
code.__globals__["__name__"] = "__main__"
simplecli.wrap(code)

help_msg = e.value.args[0]
Expand All @@ -217,3 +173,15 @@ def code(
assert "--count" in help_msg
assert "Default: 54" in help_msg
assert "OPTIONAL" not in help_msg


def test_wrap_simple_positional(capfd, monkeypatch):
monkeypatch.setattr(sys, "argv", ["filename", "1", "2", "3.5"])

def code(a: int, b: int, c: typing.Union[int, float]):
print(a + b + c)

code.__globals__["__name__"] = "__main__"
simplecli.wrap(code)
(out, _) = capfd.readouterr()
assert out.strip() == "6.5"

0 comments on commit ebbe16e

Please sign in to comment.