Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #11 - Correctly handle boolean flags #12

Merged
merged 1 commit into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions simplecli/simplecli.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,7 @@ def set_value(self, value: ValueType) -> None:
if value is DefaultIfBool:
if bool not in self.datatypes:
raise ValueError(f"'{self.help_name}' requires a value")
result = self.default if self.default is not Empty else True
elif bool in self.datatypes and self.default is Empty:
result = value
result = True if self.default is Empty else not bool(self.default)
self._value = self.annotation(result)

def set_value_as_seq(self, values: ArgList) -> None:
Expand Down
26 changes: 25 additions & 1 deletion tests/test_extract_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,14 @@ def code(foo: bool = False):
pass

params = simplecli.extract_code_params(code)
assert params == [
simplecli.Param(
name="foo",
annotation=bool,
default=False,
value=False,
),
]
params[0].set_value(True)
assert params == [
simplecli.Param(
Expand All @@ -345,6 +353,14 @@ def code(foo: bool):
pass

params = simplecli.extract_code_params(code)
assert params == [
simplecli.Param(
name="foo",
annotation=bool,
default=simplecli.Empty,
value=False,
),
]
params[0].set_value(simplecli.DefaultIfBool)
assert params == [
simplecli.Param(
Expand All @@ -361,14 +377,22 @@ def code(foo: bool = False):
pass

params = simplecli.extract_code_params(code)
params[0].set_value(simplecli.DefaultIfBool)
assert params == [
simplecli.Param(
name="foo",
annotation=bool,
default=False,
value=False,
),
]
params[0].set_value(simplecli.DefaultIfBool)
assert params == [
simplecli.Param(
name="foo",
annotation=bool,
default=False,
value=True,
),
]


Expand Down
115 changes: 37 additions & 78 deletions tests/test_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ def ensure_wrapped_not_flagged():
simplecli._wrapped = False


def simplecli_wrap_main(code):
code_name = code.__globals__["__name__"]
code.__globals__["__name__"] = "__main__"
simplecli.wrap(code)
code.__globals__["__name__"] = code_name


def test_wrap_simple(monkeypatch):
monkeypatch.setattr(sys, "argv", ["filename", "123"])

Expand All @@ -27,15 +34,9 @@ def code1(a: int):
def code2(a: int):
pass

code1_name = code1.__globals__["__name__"]
code2_name = code2.__globals__["__name__"]
with pytest.raises(SystemExit, match="only ONE"):
code1.__globals__["__name__"] = "__main__"
code2.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
simplecli.wrap(code2)
code1.__globals__["__name__"] = code1_name
code2.__globals__["__name__"] = code2_name
simplecli_wrap_main(code1)
simplecli_wrap_main(code2)


def test_wrap_help_simple(monkeypatch):
Expand All @@ -44,11 +45,8 @@ def test_wrap_help_simple(monkeypatch):
def code1(this_var: int): # stuff and things
pass

code1_name = code1.__globals__["__name__"]
with pytest.raises(SystemExit) as e:
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
code1.__globals__["__name__"] = code1_name
simplecli_wrap_main(code1)

help_msg = e.value.args[0]
assert "--this-var" in help_msg
Expand All @@ -65,11 +63,8 @@ def code(
):
pass

code_name = code.__globals__["__name__"]
with pytest.raises(SystemExit) as e:
code.__globals__["__name__"] = "__main__"
simplecli.wrap(code)
code.__globals__["__name__"] = code_name
simplecli_wrap_main(code)

help_msg = e.value.args[0]
assert "--that-var" in help_msg
Expand All @@ -84,11 +79,8 @@ def test_wrap_simple_type_error(monkeypatch):
def code(a: int):
pass

code_name = code.__globals__["__name__"]
with pytest.raises(SystemExit, match="Too many positional"):
code.__globals__["__name__"] = "__main__"
simplecli.wrap(code)
code.__globals__["__name__"] = code_name
simplecli_wrap_main(code)


def test_wrap_simple_value_error(monkeypatch):
Expand All @@ -97,11 +89,8 @@ def test_wrap_simple_value_error(monkeypatch):
def code(a: int):
pass

code_name = code.__globals__["__name__"]
with pytest.raises(SystemExit, match="must be of type int"):
code.__globals__["__name__"] = "__main__"
simplecli.wrap(code)
code.__globals__["__name__"] = code_name
simplecli_wrap_main(code)


def test_wrap_version_absent(monkeypatch):
Expand All @@ -110,11 +99,8 @@ def test_wrap_version_absent(monkeypatch):
def code2(this_var: int): # stuff and things
pass

code2_name = code2.__globals__["__name__"]
with pytest.raises(SystemExit) as e:
code2.__globals__["__name__"] = "__main__"
simplecli.wrap(code2)
code2.__globals__["__name__"] = code2_name
simplecli_wrap_main(code2)

help_msg = e.value.args[0]
assert "Description:" not in help_msg
Expand All @@ -129,12 +115,9 @@ def test_wrap_version_exists(monkeypatch):
def code1(this_var: int): # stuff and things
pass

code1_name = code1.__globals__["__name__"]
with pytest.raises(SystemExit) as e:
code1.__globals__["__name__"] = "__main__"
code1.__globals__["__version__"] = "1.2.3"
simplecli.wrap(code1)
code1.__globals__["__name__"] = code1_name
simplecli_wrap_main(code1)

help_msg = e.value.args[0]
assert "super_script version 1.2.3" in help_msg
Expand All @@ -151,11 +134,8 @@ def code2(this_var: int): # stuff and things
this is a description
"""

code2_name = code2.__globals__["__name__"]
with pytest.raises(SystemExit) as e:
code2.__globals__["__name__"] = "__main__"
simplecli.wrap(code2)
code2.__globals__["__name__"] = code2_name
simplecli_wrap_main(code2)

help_msg = e.value.args[0]
assert "Description:" in help_msg
Expand All @@ -174,11 +154,8 @@ def code(
):
pass

code_name = code.__globals__["__name__"]
with pytest.raises(SystemExit) as e:
code.__globals__["__name__"] = "__main__"
simplecli.wrap(code)
code.__globals__["__name__"] = code_name
simplecli_wrap_main(code)

help_msg = e.value.args[0]
assert "--that-var" in help_msg
Expand All @@ -205,50 +182,56 @@ def code(a: int, b: int, c: typing.Union[int, float]):
def test_wrap_boolean_false(monkeypatch):
monkeypatch.setattr(sys, "argv", ["filename"])

@simplecli.wrap
def code(is_false: bool = False):
assert is_false is False

simplecli_wrap_main(code)


def test_wrap_boolean_false_invert(monkeypatch):
monkeypatch.setattr(sys, "argv", ["filename", "--invert"])

@simplecli.wrap
def code(invert: bool = False):
assert invert is True

simplecli_wrap_main(code)


def test_wrap_boolean_true(monkeypatch):
monkeypatch.setattr(sys, "argv", ["filename"])

@simplecli.wrap
def code(is_true: bool = True):
assert is_true is True

simplecli_wrap_main(code)


def test_wrap_boolean_true_invert(monkeypatch):
monkeypatch.setattr(sys, "argv", ["filename", "--invert"])

@simplecli.wrap
def code(invert: bool = True):
assert invert is False

simplecli_wrap_main(code)


def test_wrap_boolean_true_no_default_invert(monkeypatch):
monkeypatch.setattr(sys, "argv", ["filename", "--is-something"])

@simplecli.wrap
def code(is_something: bool):
assert is_something is True

simplecli_wrap_main(code)


def test_wrap_boolean_true_no_default_no_arg(monkeypatch):
monkeypatch.setattr(sys, "argv", ["filename"])

@simplecli.wrap
def code(is_something: bool):
assert is_something is False

simplecli_wrap_main(code)


def test_directly_called_wrap(monkeypatch):
import tests.nested_test as nt
Expand All @@ -263,11 +246,8 @@ def test_wrap_no_typehint_(monkeypatch):
def code1(foo):
pass

code1_name = code1.__globals__["__name__"]
with pytest.raises(SystemExit, match="parameters need type hints!"):
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
code1.__globals__["__name__"] = code1_name
simplecli_wrap_main(code1)


def test_wrap_no_typehint_no_arg(monkeypatch):
Expand All @@ -276,11 +256,8 @@ def test_wrap_no_typehint_no_arg(monkeypatch):
def code1(foo):
pass

code1_name = code1.__globals__["__name__"]
with pytest.raises(SystemExit, match="ERROR: All wrapped function "):
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
code1.__globals__["__name__"] = code1_name
simplecli_wrap_main(code1)


def test_wrap_no_typehint_kwarg(monkeypatch):
Expand All @@ -289,11 +266,8 @@ def test_wrap_no_typehint_kwarg(monkeypatch):
def code1(foo):
pass

code1_name = code1.__globals__["__name__"]
with pytest.raises(SystemExit, match="function parameters need type"):
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
code1.__globals__["__name__"] = code1_name
simplecli_wrap_main(code1)


def test_wrap_unsupported_type(monkeypatch):
Expand All @@ -302,11 +276,8 @@ def test_wrap_unsupported_type(monkeypatch):
def code1(foo: [str, int]):
pass

code1_name = code1.__globals__["__name__"]
with pytest.raises(SystemExit, match="UnsupportedType: list"):
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
code1.__globals__["__name__"] = code1_name
simplecli_wrap_main(code1)


def test_wrap_list_of_strings(capfd, monkeypatch):
Expand All @@ -315,11 +286,8 @@ def test_wrap_list_of_strings(capfd, monkeypatch):
def code1(foo: list[str]):
print(" | ".join(foo))

code1_name = code1.__globals__["__name__"]
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
simplecli_wrap_main(code1)
assert capfd.readouterr().out == "this | is | a | test\n"
code1.__globals__["__name__"] = code1_name


def test_wrap_list_of_ints(capfd, monkeypatch):
Expand All @@ -328,11 +296,8 @@ def test_wrap_list_of_ints(capfd, monkeypatch):
def code1(foo: list[int]):
print(sum(foo))

code1_name = code1.__globals__["__name__"]
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
simplecli_wrap_main(code1)
assert capfd.readouterr().out == "83\n"
code1.__globals__["__name__"] = code1_name


def test_wrap_set_of_ints_different(capfd, monkeypatch):
Expand All @@ -341,11 +306,8 @@ def test_wrap_set_of_ints_different(capfd, monkeypatch):
def code1(foo: set[int]):
print(sum(foo))

code1_name = code1.__globals__["__name__"]
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
simplecli_wrap_main(code1)
assert capfd.readouterr().out == "83\n"
code1.__globals__["__name__"] = code1_name


def test_wrap_set_of_ints_same(capfd, monkeypatch):
Expand All @@ -354,8 +316,5 @@ def test_wrap_set_of_ints_same(capfd, monkeypatch):
def code1(foo: set[int]):
print(sum(foo))

code1_name = code1.__globals__["__name__"]
code1.__globals__["__name__"] = "__main__"
simplecli.wrap(code1)
simplecli_wrap_main(code1)
assert capfd.readouterr().out == "9\n"
code1.__globals__["__name__"] = code1_name