diff --git a/pyproject.toml b/pyproject.toml index e8b9a30..703bd05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "python-main" -version = "1.0.1" +version = "1.0.2" homepage = "https://github.com/flipbit03/main" description = "Decorator which runs the tagged function if the current module is being run as a script. No more `if __name__ == \"__main__\"` madness." authors = ["Cadu "] diff --git a/python_main/__init__.py b/python_main/__init__.py index 2814897..9b4c3be 100644 --- a/python_main/__init__.py +++ b/python_main/__init__.py @@ -1,13 +1,10 @@ -import builtins -import inspect from typing import Callable, Optional +__RAN_AS_SCRIPT_MODULE = "__main__" +__CALLABLE_MODULE_PROP = "__module__" -def main(f: Callable[[], Optional[int]]) -> None: - curr_frame = inspect.currentframe() - assert curr_frame is not None - assert curr_frame.f_back is not None - upper_frame = curr_frame.f_back - if upper_frame.f_locals["__name__"] == "__main__": - builtins.exit(f() or 0) +def main(f: Callable[[], Optional[int]]) -> Callable: + if getattr(f, __CALLABLE_MODULE_PROP) == __RAN_AS_SCRIPT_MODULE: + f() + return f diff --git a/tests/test_basic.py b/tests/test_basic.py index eba2837..21ae8d7 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,19 +1,15 @@ +import builtins + import pytest from python_main import main - -class NotSet: - pass - - -EXIT_CODE_RECEIVED = NotSet +EXIT_CODE_RECEIVED = -1 @pytest.fixture def mock_exit(): - import builtins - + global EXIT_CODE_RECEIVED original_exit = builtins.exit def mock_exit(code): @@ -27,7 +23,14 @@ def mock_exit(code): # Clean up builtins.exit = original_exit - EXIT_CODE_RECEIVED = NotSet + EXIT_CODE_RECEIVED = -1 + + +def __my_main_func(): + """ + The answer to life, the universe, and everything. + """ + builtins.exit(42) def test_assert_function_actually_gets_called(mock_exit): @@ -35,17 +38,31 @@ def test_assert_function_actually_gets_called(mock_exit): Assert that the @main decorator actually calls the function if the module is being run as a script. """ - # We patch __name__ here because doing so via a proper pytest fixture would be _A Lot Of Work (TM)_, because - # of the insane amount of stack manipulation that would be required to get the desired effect. - # The "noqa" flag here is important, or else our pre-commit hooks (flake) will remove this assignment. - __name__ = "__main__" # noqa + # We patch my_main_func's __module__ here so that we can emulate that it comes from a module which + # is being run as a script/ + __my_main_func_original_module = __my_main_func.__module__ + __my_main_func.__module__ = "__main__" - @main - def my_main_func(): - """ - The answer to life, the universe, and everything. - """ - return 42 + # Decorate it + main(__my_main_func) # Ensure that our main function was able to call mock_exit with the expected value. + global EXIT_CODE_RECEIVED assert EXIT_CODE_RECEIVED == 42 + + # Restore + __my_main_func.__module__ = __my_main_func_original_module + + +def test_assert_function_does_not_get_called(mock_exit): + """ + Assert that our decorated function does not get called in normal circumstances + """ + + # Call the function, which is coming from a pytest execution and being imported as a module + function_returned = main(__my_main_func) + + # Exit code will not have been set. + global EXIT_CODE_RECEIVED + assert EXIT_CODE_RECEIVED == -1 + assert function_returned == __my_main_func