From 1f6a8b454471b40b7a9d35031dbcb7d2b2a20d5d Mon Sep 17 00:00:00 2001 From: Patrick Ogenstad Date: Fri, 20 Sep 2024 19:10:14 +0200 Subject: [PATCH] Update typehints within ./tests --- pyproject.toml | 48 +++++++++++++-- tests/conftest.py | 10 ++-- tests/core/test_InitNornir.py | 37 ++++++------ tests/core/test_configuration.py | 18 +++--- tests/core/test_connections.py | 58 ++++++++++--------- tests/core/test_filter.py | 49 ++++++++-------- tests/core/test_inventory.py | 51 ++++++++-------- tests/core/test_processors.py | 4 +- tests/core/test_registered_plugins.py | 4 +- tests/core/test_tasks.py | 43 ++++++++------ .../inventory/test_simple_inventory.py | 4 +- tests/plugins/processors/test_serial.py | 16 ++--- tests/plugins/processors/test_threaded.py | 24 ++++---- tests/wrapper.py | 5 +- 14 files changed, 215 insertions(+), 156 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5ba68496..0d31152b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,8 +74,48 @@ warn_return_any = true warn_redundant_casts = true [[tool.mypy.overrides]] -module = "tests.*" -ignore_errors = true +module = "tests.conftest" +disable_error_code = [ + "return-value", +] + +[[tool.mypy.overrides]] +module = "tests.core.test_InitNornir" +disable_error_code = [ + "arg-type", +] + +[[tool.mypy.overrides]] +module = "tests.core.test_filter" +disable_error_code = [ + "arg-type", + "assignment", + "operator" +] + +[[tool.mypy.overrides]] +module = "tests.core.test_inventory" +disable_error_code = [ + "arg-type", + "index" +] + +[[tool.mypy.overrides]] +module = "tests.core.test_processors" +disable_error_code = [ + "assignment", + "index", + "var-annotated" +] + +[[tool.mypy.overrides]] +module = "tests.wrapper" +disable_error_code = [ + "import-untyped", + "misc", + "no-any-return" +] + [tool.ruff] line-length = 100 @@ -232,10 +272,6 @@ max-returns = 11 # like this so that we can reactivate them one by one. Alternatively ignored after further # # investigation if they are deemed to not make sense. # ################################################################################################## - "ANN001", # Missing type annotation for function argument - "ANN002", # Missing type annotation for `*args` - "ANN003", # Missing type annotation for `**kwargs` - "ANN201", # Missing return type annotation for public function "ARG001", # Unused function argument "B007", # Loop control variable `host` not used within loop body "C414", # Unnecessary `list` call within `sorted()` diff --git a/tests/conftest.py b/tests/conftest.py index f48ac62b..d9d2c4a4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,11 +23,11 @@ global_data = GlobalState(dry_run=True) -def inventory_from_yaml(): +def inventory_from_yaml() -> Inventory: dir_path = os.path.dirname(os.path.realpath(__file__)) yml = ruamel.yaml.YAML(typ="safe") - def get_connection_options(data) -> Dict[str, ConnectionOptions]: + def get_connection_options(data: Dict[str, Any]) -> Dict[str, ConnectionOptions]: cp = {} for cn, c in data.items(): cp[cn] = ConnectionOptions( @@ -119,17 +119,17 @@ def run(self, task: Task, hosts: List[Host]) -> AggregatedResult: @pytest.fixture(scope="session", autouse=True) -def inv(request): +def inv() -> Inventory: return inventory_from_yaml() @pytest.fixture(scope="session", autouse=True) -def nornir(request): +def nornir() -> Nornir: """Initializes nornir""" return Nornir(inventory=inventory_from_yaml(), runner=SerialRunner(), data=global_data) @pytest.fixture(scope="function", autouse=True) -def reset_data(): +def reset_data() -> None: global_data.dry_run = True global_data.reset_failed_hosts() diff --git a/tests/core/test_InitNornir.py b/tests/core/test_InitNornir.py index 6d863546..27cd2672 100644 --- a/tests/core/test_InitNornir.py +++ b/tests/core/test_InitNornir.py @@ -1,6 +1,7 @@ import logging import logging.config import os +from typing import Any, Dict import pytest @@ -36,19 +37,19 @@ } -def transform_func(host): +def transform_func(host: Host) -> None: host["processed_by_transform_function"] = True -def transform_func_with_options(host, a): +def transform_func_with_options(host: Host, a: Any) -> None: host["a"] = a class InventoryTest: - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Dict[str, Any]) -> None: pass - def load(self): + def load(self) -> Inventory: return Inventory( hosts=Hosts({"h1": Host("h1"), "h2": Host("h2"), "h3": Host("h3")}), groups=Groups({"g1": Group("g1")}), @@ -62,14 +63,14 @@ def load(self): class Test: - def test_InitNornir_bare(self): + def test_InitNornir_bare(self) -> None: os.chdir("tests/inventory_data/") nr = InitNornir() os.chdir("../../") assert len(nr.inventory.hosts) assert len(nr.inventory.groups) - def test_InitNornir_defaults(self): + def test_InitNornir_defaults(self) -> None: os.chdir("tests/inventory_data/") nr = InitNornir(inventory={"plugin": "inventory-test"}) os.chdir("../../") @@ -78,13 +79,13 @@ def test_InitNornir_defaults(self): assert len(nr.inventory.hosts) assert len(nr.inventory.groups) - def test_InitNornir_file(self): + def test_InitNornir_file(self) -> None: nr = InitNornir(config_file=os.path.join(dir_path, "a_config.yaml")) assert not nr.data.dry_run assert len(nr.inventory.hosts) assert len(nr.inventory.groups) - def test_InitNornir_programmatically(self): + def test_InitNornir_programmatically(self) -> None: nr = InitNornir( core={"raise_on_error": True}, inventory={ @@ -100,14 +101,14 @@ def test_InitNornir_programmatically(self): assert len(nr.inventory.hosts) assert len(nr.inventory.groups) - def test_InitNornir_override_partial_section(self): + def test_InitNornir_override_partial_section(self) -> None: nr = InitNornir( config_file=os.path.join(dir_path, "a_config.yaml"), core={"raise_on_error": True}, ) assert nr.config.core.raise_on_error - def test_InitNornir_combined(self): + def test_InitNornir_combined(self) -> None: nr = InitNornir( config_file=os.path.join(dir_path, "a_config.yaml"), core={"raise_on_error": True}, @@ -117,7 +118,7 @@ def test_InitNornir_combined(self): assert len(nr.inventory.hosts) assert len(nr.inventory.groups) - def test_InitNornir_different_transform_function_by_string(self): + def test_InitNornir_different_transform_function_by_string(self) -> None: nr = InitNornir( config_file=os.path.join(dir_path, "a_config.yaml"), inventory={ @@ -132,7 +133,7 @@ def test_InitNornir_different_transform_function_by_string(self): for host in nr.inventory.hosts.values(): assert host["processed_by_transform_function"] - def test_InitNornir_different_transform_function_by_string_with_options(self): + def test_InitNornir_different_transform_function_by_string_with_options(self) -> None: nr = InitNornir( config_file=os.path.join(dir_path, "a_config.yaml"), inventory={ @@ -148,7 +149,7 @@ def test_InitNornir_different_transform_function_by_string_with_options(self): for host in nr.inventory.hosts.values(): assert host["a"] == 1 - def test_InitNornir_different_transform_function_by_string_with_bad_options(self): + def test_InitNornir_different_transform_function_by_string_with_bad_options(self) -> None: with pytest.raises(TypeError): nr = InitNornir( config_file=os.path.join(dir_path, "a_config.yaml"), @@ -186,7 +187,7 @@ def cleanup(cls) -> None: def teardown_class(cls) -> None: cls.cleanup() - def test_InitNornir_logging_defaults(self): + def test_InitNornir_logging_defaults(self) -> None: self.cleanup() InitNornir( config_file=os.path.join(dir_path, "a_config.yaml"), @@ -197,7 +198,7 @@ def test_InitNornir_logging_defaults(self): assert len(nornir_logger.handlers) == 1 assert isinstance(nornir_logger.handlers[0], logging.FileHandler) - def test_InitNornir_logging_to_console(self): + def test_InitNornir_logging_to_console(self) -> None: self.cleanup() InitNornir( config_file=os.path.join(dir_path, "a_config.yaml"), @@ -210,7 +211,7 @@ def test_InitNornir_logging_to_console(self): assert any(isinstance(handler, logging.FileHandler) for handler in nornir_logger.handlers) assert any(isinstance(handler, logging.StreamHandler) for handler in nornir_logger.handlers) - def test_InitNornir_logging_disabled(self): + def test_InitNornir_logging_disabled(self) -> None: self.cleanup() InitNornir( config_file=os.path.join(dir_path, "a_config.yaml"), @@ -220,7 +221,7 @@ def test_InitNornir_logging_disabled(self): assert nornir_logger.level == logging.NOTSET - def test_InitNornir_logging_basicConfig(self): + def test_InitNornir_logging_basicConfig(self) -> None: self.cleanup() logging.basicConfig() with pytest.warns(ConflictingConfigurationWarning): @@ -231,7 +232,7 @@ def test_InitNornir_logging_basicConfig(self): assert nornir_logger.level == logging.INFO assert nornir_logger.hasHandlers() - def test_InitNornir_logging_dictConfig(self): + def test_InitNornir_logging_dictConfig(self) -> None: self.cleanup() logging.config.dictConfig(LOGGING_DICT) with pytest.warns(ConflictingConfigurationWarning): diff --git a/tests/core/test_configuration.py b/tests/core/test_configuration.py index ce44de1e..947c8dbc 100644 --- a/tests/core/test_configuration.py +++ b/tests/core/test_configuration.py @@ -10,7 +10,7 @@ class Test: - def test_config_defaults(self): + def test_config_defaults(self) -> None: c = Config() assert c.dict() == { "core": {"raise_on_error": False}, @@ -33,7 +33,7 @@ def test_config_defaults(self): "user_defined": {}, } - def test_config_from_dict_defaults(self): + def test_config_from_dict_defaults(self) -> None: c = Config.from_dict() assert c.dict() == { "core": {"raise_on_error": False}, @@ -56,7 +56,7 @@ def test_config_from_dict_defaults(self): "user_defined": {}, } - def test_config_basic(self): + def test_config_basic(self) -> None: c = Config.from_dict( inventory={"plugin": "an-inventory"}, runner={"plugin": "serial", "options": {"a": 1, "b": 2}}, @@ -84,14 +84,14 @@ def test_config_basic(self): "user_defined": {"my_opt": True}, } - def test_configuration_file_override_argument(self): + def test_configuration_file_override_argument(self) -> None: config = Config.from_file( os.path.join(dir_path, "config.yaml"), core={"raise_on_error": True}, ) assert config.core.raise_on_error - def test_configuration_file_override_env(self): + def test_configuration_file_override_env(self) -> None: os.environ["NORNIR_CORE_RAISE_ON_ERROR"] = "1" os.environ["NORNIR_SSH_CONFIG_FILE"] = "/user/ssh_config" config = Config.from_dict(inventory={"plugin": "an-inventory"}) @@ -100,22 +100,22 @@ def test_configuration_file_override_env(self): os.environ.pop("NORNIR_CORE_RAISE_ON_ERROR") os.environ.pop("NORNIR_SSH_CONFIG_FILE") - def test_configuration_bool_env(self): + def test_configuration_bool_env(self) -> None: os.environ["NORNIR_CORE_RAISE_ON_ERROR"] = "0" config = Config.from_dict(inventory={"plugin": "an-inventory"}) assert not config.core.raise_on_error - def test_get_user_defined_from_file(self): + def test_get_user_defined_from_file(self) -> None: config = Config.from_file(os.path.join(dir_path, "config.yaml")) assert config.user_defined["asd"] == "qwe" - def test_order_of_resolution_config_higher_than_env(self): + def test_order_of_resolution_config_higher_than_env(self) -> None: os.environ["NORNIR_CORE_RAISE_ON_ERROR"] = "1" config = Config.from_file(os.path.join(dir_path, "config.yaml")) os.environ.pop("NORNIR_CORE_RAISE_ON_ERROR") assert config.core.raise_on_error is False - def test_order_of_resolution_code_is_higher_than_env(self): + def test_order_of_resolution_code_is_higher_than_env(self) -> None: os.environ["NORNIR_CORE_RAISE_ON_ERROR"] = "0" config = Config.from_file( os.path.join(dir_path, "config.yaml"), core={"raise_on_error": True} diff --git a/tests/core/test_connections.py b/tests/core/test_connections.py index 9194b61d..81db0d54 100644 --- a/tests/core/test_connections.py +++ b/tests/core/test_connections.py @@ -2,6 +2,7 @@ import pytest +from nornir.core import Nornir from nornir.core.configuration import Config from nornir.core.exceptions import ( ConnectionAlreadyOpen, @@ -10,6 +11,7 @@ PluginNotRegistered, ) from nornir.core.plugins.connections import ConnectionPluginRegister +from nornir.core.task import Task class DummyConnectionPlugin: @@ -57,20 +59,24 @@ def open( extras: Optional[Dict[str, Any]] = None, configuration: Optional[Config] = None, ) -> None: - raise FailedConnection(f"Failed to open connection to {self.hostname}:{self.port}") + raise FailedConnection(f"Failed to open connection to {hostname}:{port}") def close(self) -> None: pass + @property + def connection(self) -> Any: + """Used to fullfill protocol specs.""" -def open_and_close_connection(task, nornir_config): + +def open_and_close_connection(task: Task, nornir_config: Config) -> None: task.host.open_connection("dummy", nornir_config) assert "dummy" in task.host.connections task.host.close_connection("dummy") assert "dummy" not in task.host.connections -def open_connection_twice(task, nornir_config): +def open_connection_twice(task: Task, nornir_config: Config) -> None: task.host.open_connection("dummy", nornir_config) assert "dummy" in task.host.connections try: @@ -81,7 +87,7 @@ def open_connection_twice(task, nornir_config): assert "dummy" not in task.host.connections -def close_not_opened_connection(task): +def close_not_opened_connection(task: Task) -> None: assert "dummy" not in task.host.connections try: task.host.close_connection("dummy") @@ -90,15 +96,15 @@ def close_not_opened_connection(task): assert "dummy" not in task.host.connections -def failed_connection(task, nornir_config): +def failed_connection(task: Task, nornir_config: Config) -> None: task.host.open_connection(FailedConnectionPlugin.name, nornir_config) -def a_task(task, nornir_config): +def a_task(task: Task, nornir_config: Config) -> None: task.host.get_connection("dummy", nornir_config) -def validate_params(task, conn, params, nornir_config): +def validate_params(task: Task, conn: str, params: Dict[str, Any], nornir_config: Config) -> None: task.host.get_connection(conn, nornir_config) for k, v in params.items(): assert getattr(task.host.connections[conn], k) == v @@ -113,36 +119,36 @@ def setup_class(cls) -> None: ConnectionPluginRegister.register("dummy_no_overrides", DummyConnectionPlugin) ConnectionPluginRegister.register(FailedConnectionPlugin.name, FailedConnectionPlugin) - def test_open_and_close_connection(self, nornir): + def test_open_and_close_connection(self, nornir: Nornir) -> None: nr = nornir.filter(name="dev2.group_1") r = nr.run(task=open_and_close_connection, nornir_config=nornir.config) assert len(r) == 1 assert not r.failed - def test_open_connection_twice(self, nornir): + def test_open_connection_twice(self, nornir: Nornir) -> None: nr = nornir.filter(name="dev2.group_1") r = nr.run(task=open_connection_twice, nornir_config=nornir.config) assert len(r) == 1 assert not r.failed - def test_close_not_opened_connection(self, nornir): + def test_close_not_opened_connection(self, nornir: Nornir) -> None: nr = nornir.filter(name="dev2.group_1") r = nr.run(task=close_not_opened_connection) assert len(r) == 1 assert not r.failed - def test_failed_connection(self, nornir): + def test_failed_connection(self, nornir: Nornir) -> None: nr = nornir.filter(name="dev2.group_1") nr.run(task=failed_connection, nornir_config=nornir.config) assert FailedConnectionPlugin.name not in nornir.inventory.hosts["dev2.group_1"].connections - def test_context_manager(self, nornir): + def test_context_manager(self, nornir: Nornir) -> None: with nornir.filter(name="dev2.group_1") as nr: nr.run(task=a_task, nornir_config=nornir.config) assert "dummy" in nr.inventory.hosts["dev2.group_1"].connections assert "dummy" not in nr.inventory.hosts["dev2.group_1"].connections - def test_validate_params_simple(self, nornir): + def test_validate_params_simple(self, nornir: Nornir) -> None: params = { "hostname": "localhost", "username": "root", @@ -161,7 +167,7 @@ def test_validate_params_simple(self, nornir): assert len(r) == 1 assert not r.failed - def test_validate_params_overrides(self, nornir): + def test_validate_params_overrides(self, nornir: Nornir) -> None: params = { "port": 65021, "hostname": "dummy_from_parent_group", @@ -180,7 +186,7 @@ def test_validate_params_overrides(self, nornir): assert len(r) == 1 assert not r.failed - def test_validate_params_overrides_groups(self, nornir): + def test_validate_params_overrides_groups(self, nornir: Nornir) -> None: params = { "port": 65021, "hostname": "dummy2_from_parent_group", @@ -201,48 +207,48 @@ def test_validate_params_overrides_groups(self, nornir): class TestConnectionPluginsRegistration: - def setup_method(self, method): + def setup_method(self) -> None: ConnectionPluginRegister.deregister_all() ConnectionPluginRegister.register("dummy", DummyConnectionPlugin) ConnectionPluginRegister.register("another_dummy", AnotherDummyConnectionPlugin) - def teardown_method(self, method): + def teardown_method(self) -> None: ConnectionPluginRegister.deregister_all() ConnectionPluginRegister.auto_register() - def test_count(self): + def test_count(self) -> None: assert len(ConnectionPluginRegister.available) == 2 - def test_register_new(self): + def test_register_new(self) -> None: ConnectionPluginRegister.register("new_dummy", DummyConnectionPlugin) assert "new_dummy" in ConnectionPluginRegister.available - def test_register_already_registered_same(self): + def test_register_already_registered_same(self) -> None: ConnectionPluginRegister.register("dummy", DummyConnectionPlugin) assert ConnectionPluginRegister.available["dummy"] == DummyConnectionPlugin - def test_register_already_registered_new(self): + def test_register_already_registered_new(self) -> None: with pytest.raises(PluginAlreadyRegistered): ConnectionPluginRegister.register("dummy", AnotherDummyConnectionPlugin) - def test_deregister_existing(self): + def test_deregister_existing(self) -> None: ConnectionPluginRegister.deregister("dummy") assert len(ConnectionPluginRegister.available) == 1 assert "dummy" not in ConnectionPluginRegister.available - def test_deregister_nonexistent(self): + def test_deregister_nonexistent(self) -> None: with pytest.raises(PluginNotRegistered): ConnectionPluginRegister.deregister("nonexistent_dummy") - def test_deregister_all(self): + def test_deregister_all(self) -> None: ConnectionPluginRegister.deregister_all() assert ConnectionPluginRegister.available == {} - def test_get_plugin(self): + def test_get_plugin(self) -> None: assert ConnectionPluginRegister.get_plugin("dummy") == DummyConnectionPlugin assert ConnectionPluginRegister.get_plugin("another_dummy") == AnotherDummyConnectionPlugin assert len(ConnectionPluginRegister.available) == 2 - def test_nonexistent_plugin(self): + def test_nonexistent_plugin(self) -> None: with pytest.raises(PluginNotRegistered): ConnectionPluginRegister.get_plugin("nonexistent_dummy") diff --git a/tests/core/test_filter.py b/tests/core/test_filter.py index 74c27f93..dea197a3 100644 --- a/tests/core/test_filter.py +++ b/tests/core/test_filter.py @@ -1,28 +1,29 @@ import pytest +from nornir.core import Nornir from nornir.core.filter import AND, OR, F class Test: - def test_simple(self, nornir): + def test_simple(self, nornir: Nornir) -> None: f = F(site="site1") filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) assert filtered == ["dev1.group_1", "dev2.group_1"] - def test_and(self, nornir): + def test_and(self, nornir: Nornir) -> None: f = F(site="site1") & F(role="www") filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) assert filtered == ["dev1.group_1"] - def test_or(self, nornir): + def test_or(self, nornir: Nornir) -> None: f = F(site="site1") | F(role="www") filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) assert filtered == ["dev1.group_1", "dev2.group_1", "dev3.group_2"] - def test_combined(self, nornir): + def test_combined(self, nornir: Nornir) -> None: f = F(site="site2") | (F(role="www") & F(my_var="comes_from_dev1.group_1")) filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) @@ -38,13 +39,13 @@ def test_combined(self, nornir): assert filtered == ["dev1.group_1"] - def test_contains(self, nornir): + def test_contains(self, nornir: Nornir) -> None: f = F(groups__contains="group_1") filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) assert filtered == ["dev1.group_1", "dev2.group_1"] - def test_negate(self, nornir): + def test_negate(self, nornir: Nornir) -> None: f = ~F(groups__contains="group_1") filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) @@ -55,13 +56,13 @@ def test_negate(self, nornir): "dev6.group_3", ] - def test_negate_and_second_negate(self, nornir): + def test_negate_and_second_negate(self, nornir: Nornir) -> None: f = F(site="site1") & ~F(role="www") filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) assert filtered == ["dev2.group_1"] - def test_negate_or_both_negate(self, nornir): + def test_negate_or_both_negate(self, nornir: Nornir) -> None: f = ~F(site="site1") | ~F(role="www") filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) @@ -73,31 +74,31 @@ def test_negate_or_both_negate(self, nornir): "dev6.group_3", ] - def test_nested_data_a_string(self, nornir): + def test_nested_data_a_string(self, nornir: Nornir) -> None: f = F(nested_data__a_string="asdasd") filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) assert filtered == ["dev1.group_1"] - def test_nested_data_a_string_contains(self, nornir): + def test_nested_data_a_string_contains(self, nornir: Nornir) -> None: f = F(nested_data__a_string__contains="asd") filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) assert filtered == ["dev1.group_1"] - def test_nested_data_a_dict_contains(self, nornir): + def test_nested_data_a_dict_contains(self, nornir: Nornir) -> None: f = F(nested_data__a_dict__contains="a") filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) assert filtered == ["dev1.group_1"] - def test_nested_data_a_dict_element(self, nornir): + def test_nested_data_a_dict_element(self, nornir: Nornir) -> None: f = F(nested_data__a_dict__a=1) filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) assert filtered == ["dev1.group_1"] - def test_nested_data_a_dict_doesnt_contain(self, nornir): + def test_nested_data_a_dict_doesnt_contain(self, nornir: Nornir) -> None: f = ~F(nested_data__a_dict__contains="a") filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) @@ -109,13 +110,13 @@ def test_nested_data_a_dict_doesnt_contain(self, nornir): "dev6.group_3", ] - def test_nested_data_a_list_contains(self, nornir): + def test_nested_data_a_list_contains(self, nornir: Nornir) -> None: f = F(nested_data__a_list__contains=2) filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) assert filtered == ["dev1.group_1", "dev2.group_1"] - def test_filtering_by_callable_has_parent_group(self, nornir): + def test_filtering_by_callable_has_parent_group(self, nornir: Nornir) -> None: f = F(has_parent_group="parent_group") filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) @@ -126,13 +127,13 @@ def test_filtering_by_callable_has_parent_group(self, nornir): "dev6.group_3", ] - def test_filtering_by_attribute_name(self, nornir): + def test_filtering_by_attribute_name(self, nornir: Nornir) -> None: f = F(name="dev1.group_1") filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) assert filtered == ["dev1.group_1"] - def test_filtering_string_in_list(self, nornir): + def test_filtering_string_in_list(self, nornir: Nornir) -> None: f = F(platform__in=["linux", "mock"]) filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) @@ -143,31 +144,31 @@ def test_filtering_string_in_list(self, nornir): "dev6.group_3", ] - def test_filtering_string_any(self, nornir): + def test_filtering_string_any(self, nornir: Nornir) -> None: f = F(some_string_to_test_any_all__any=["prefix", "other_prefix"]) filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) assert filtered == ["dev1.group_1", "dev3.group_2", "dev4.group_2"] - def test_filtering_list_any(self, nornir): + def test_filtering_list_any(self, nornir: Nornir) -> None: f = F(nested_data__a_list__any=[1, 3]) filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) assert filtered == ["dev1.group_1", "dev2.group_1"] - def test_filtering_list_all(self, nornir): + def test_filtering_list_all(self, nornir: Nornir) -> None: f = F(nested_data__a_list__all=[1, 2]) filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) assert filtered == ["dev1.group_1"] - def test_filter_wrong_attribute_for_type(self, nornir): + def test_filter_wrong_attribute_for_type(self, nornir: Nornir) -> None: f = F(port__startswith="a") filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) assert filtered == [] - def test_eq__on_not_existing_key(self, nornir): + def test_eq__on_not_existing_key(self, nornir: Nornir) -> None: f = F(not_existing__eq="test") filtered = sorted(list((nornir.inventory.filter(f).hosts.keys()))) @@ -183,7 +184,7 @@ def test_eq__on_not_existing_key(self, nornir): (F(site="site1") | F(role="www"), OR(F(role="www"), F(site="site1"))), ], ) - def test_compare_filter_equal(self, filter_a, filter_b): + def test_compare_filter_equal(self, filter_a: F, filter_b: F) -> None: assert filter_a == filter_b @pytest.mark.parametrize( @@ -194,5 +195,5 @@ def test_compare_filter_equal(self, filter_a, filter_b): (F(site="site1"), ~F(site="site1")), ], ) - def test_compare_filter_not_equal(self, filter_a, filter_b): + def test_compare_filter_not_equal(self, filter_a: F, filter_b: F) -> None: assert filter_a != filter_b diff --git a/tests/core/test_inventory.py b/tests/core/test_inventory.py index 74ddf66c..0f9c5702 100644 --- a/tests/core/test_inventory.py +++ b/tests/core/test_inventory.py @@ -4,6 +4,7 @@ import ruamel.yaml from nornir.core import inventory +from nornir.core.inventory import Host yaml = ruamel.yaml.YAML(typ="safe") dir_path = os.path.dirname(os.path.realpath(__file__)) @@ -17,7 +18,7 @@ class Test: - def test_host(self): + def test_host(self) -> None: h = inventory.Host(name="host1", hostname="host1") assert h.hostname == "host1" assert h.port is None @@ -43,7 +44,7 @@ def test_host(self): assert h.platform == "fake" assert h.data == data - def test_inventory(self): + def test_inventory(self) -> None: g1 = inventory.Group(name="g1") g2 = inventory.Group(name="g2", groups=inventory.ParentGroups([g1])) h1 = inventory.Host(name="h1", groups=inventory.ParentGroups([g1, g2])) @@ -58,7 +59,7 @@ def test_inventory(self): assert inv.groups["g1"] in inv.hosts["h1"].groups assert inv.groups["g1"] in inv.groups["g2"].groups - def test_inventory_data(self, inv): + def test_inventory_data(self, inv: inventory.Inventory) -> None: """Test Host values()/keys()/items()""" h = inv.hosts["dev1.group_1"] assert "comes_from_dev1.group_1" in h.values() @@ -67,7 +68,7 @@ def test_inventory_data(self, inv): assert "only_default" in h.keys() assert dict(h.items())["my_var"] == "comes_from_dev1.group_1" - def test_inventory_dict(self, inv): + def test_inventory_dict(self, inv: inventory.Inventory) -> None: assert inv.dict() == { "defaults": { "connection_options": { @@ -319,7 +320,7 @@ def test_inventory_dict(self, inv): }, } - def test_extended_data(self, inv): + def test_extended_data(self, inv: inventory.Inventory) -> None: assert inv.hosts["dev1.group_1"].extended_data() == { "a_false_var": False, "a_var": "blah", @@ -356,7 +357,7 @@ def test_extended_data(self, inv): "site": "site2", } - def test_parent_groups_extended(self, inv): + def test_parent_groups_extended(self, inv: inventory.Inventory) -> None: assert inv.hosts["dev1.group_1"].extended_groups() == [ inv.groups["group_1"], inv.groups["parent_group"], @@ -371,7 +372,7 @@ def test_parent_groups_extended(self, inv): inv.groups["parent_group"], ] - def test_filtering(self, inv): + def test_filtering(self, inv: inventory.Inventory) -> None: unfiltered = sorted(list(inv.hosts.keys())) assert unfiltered == [ "dev1.group_1", @@ -391,23 +392,23 @@ def test_filtering(self, inv): www_site1 = sorted(list(inv.filter(role="www").filter(site="site1").hosts.keys())) assert www_site1 == ["dev1.group_1"] - def test_filtering_func(self, inv): + def test_filtering_func(self, inv: inventory.Inventory) -> None: long_names = sorted( list(inv.filter(filter_func=lambda x: len(x["my_var"]) > 20).hosts.keys()) ) assert long_names == ["dev1.group_1", "dev4.group_2", "dev6.group_3"] - def longer_than(dev, length) -> bool: + def longer_than(dev: Host, length: int) -> bool: return len(dev["my_var"]) > length long_names = sorted(list(inv.filter(filter_func=longer_than, length=20).hosts.keys())) assert long_names == ["dev1.group_1", "dev4.group_2", "dev6.group_3"] - def test_filter_unique_keys(self, inv): + def test_filter_unique_keys(self, inv: inventory.Inventory) -> None: filtered = sorted(list(inv.filter(www_server="nginx").hosts.keys())) assert filtered == ["dev1.group_1"] - def test_var_resolution(self, inv): + def test_var_resolution(self, inv: inventory.Inventory) -> None: assert inv.hosts["dev1.group_1"]["my_var"] == "comes_from_dev1.group_1" assert inv.hosts["dev2.group_1"]["my_var"] == "comes_from_group_1" assert inv.hosts["dev3.group_2"]["my_var"] == "comes_from_defaults" @@ -423,7 +424,7 @@ def test_var_resolution(self, inv): inv.hosts["dev3.group_2"].data["my_var"] assert inv.hosts["dev4.group_2"].data["my_var"] == "comes_from_dev4.group_2" - def test_attributes_resolution(self, inv): + def test_attributes_resolution(self, inv: inventory.Inventory) -> None: assert inv.hosts["dev1.group_1"].password == "a_password" assert inv.hosts["dev2.group_1"].password == "from_group1" assert inv.hosts["dev3.group_2"].password == "docker" @@ -431,13 +432,13 @@ def test_attributes_resolution(self, inv): assert inv.hosts["dev5.no_group"].password == "docker" assert inv.hosts["dev6.group_3"].password == "from_parent_group" - def test_has_parents(self, inv): + def test_has_parents(self, inv: inventory.Inventory) -> None: assert inv.hosts["dev1.group_1"].has_parent_group(inv.groups["group_1"]) assert not inv.hosts["dev1.group_1"].has_parent_group(inv.groups["group_2"]) assert inv.hosts["dev1.group_1"].has_parent_group("group_1") assert not inv.hosts["dev1.group_1"].has_parent_group("group_2") - def test_get_connection_parameters(self, inv): + def test_get_connection_parameters(self, inv: inventory.Inventory) -> None: p1 = inv.hosts["dev1.group_1"].get_connection_parameters("dummy") assert p1.port == 65020 assert p1.hostname == "dummy_from_host" @@ -467,7 +468,7 @@ def test_get_connection_parameters(self, inv): assert p4.platform == "linux" assert p4.extras == {"blah": "from_defaults"} - def test_defaults(self, inv): + def test_defaults(self, inv: inventory.Inventory) -> None: inv.defaults.password = "asd" assert inv.defaults.password == "asd" assert inv.hosts["dev2.group_1"].password == "from_group1" @@ -475,7 +476,7 @@ def test_defaults(self, inv): assert inv.hosts["dev4.group_2"].password == "from_parent_group" assert inv.hosts["dev5.no_group"].password == "asd" - def test_children_of_str(self, inv): + def test_children_of_str(self, inv: inventory.Inventory) -> None: assert inv.children_of_group("parent_group") == { inv.hosts["dev1.group_1"], inv.hosts["dev2.group_1"], @@ -495,7 +496,7 @@ def test_children_of_str(self, inv): assert inv.children_of_group("blah") == set() - def test_children_of_obj(self, inv): + def test_children_of_obj(self, inv: inventory.Inventory) -> None: assert inv.children_of_group(inv.groups["parent_group"]) == { inv.hosts["dev1.group_1"], inv.hosts["dev2.group_1"], @@ -513,7 +514,7 @@ def test_children_of_obj(self, inv): inv.hosts["dev3.group_2"], } - def test_add_host(self): + def test_add_host(self) -> None: data = {"test_var": "test_value"} defaults = inventory.Defaults(data=data) g1 = inventory.Group(name="g1") @@ -538,7 +539,7 @@ def test_add_host(self): assert inv.hosts["h3"].platform == "TestPlatform" assert inv.hosts["h3"].connection_options["netmiko"].extras["device_type"] == "cisco_ios" - def test_add_group(self): + def test_add_group(self) -> None: connection_options = {"username": "test_user", "password": "test_pass"} data = {"test_var": "test_value"} defaults = inventory.Defaults(data=data, connection_options=connection_options) @@ -564,7 +565,7 @@ def test_add_group(self): assert inv.groups["g3"].defaults.data.get("test_var") == "test_value" assert inv.groups["g3"].connection_options["netmiko"].extras["device_type"] == "cisco_ios" - def test_dict(self, inv): + def test_dict(self, inv: inventory.Inventory) -> None: inventory_dict = inv.dict() def_extras = inventory_dict["defaults"]["connection_options"]["dummy"]["extras"] grp_data = inventory_dict["groups"]["group_1"]["data"] @@ -575,7 +576,7 @@ def test_dict(self, inv): assert "my_var" and "site" in grp_data assert "www_server" and "role" in host_data - def test_get_defaults_dict(self, inv): + def test_get_defaults_dict(self, inv: inventory.Inventory) -> None: defaults_dict = inv.defaults.dict() con_options = defaults_dict["connection_options"]["dummy"] assert isinstance(defaults_dict, dict) @@ -583,13 +584,13 @@ def test_get_defaults_dict(self, inv): assert con_options["hostname"] == "dummy_from_defaults" assert "blah" in con_options["extras"] - def test_get_groups_dict(self, inv): + def test_get_groups_dict(self, inv: inventory.Inventory) -> None: groups_dict = {n: g.dict() for n, g in inv.groups.items()} assert isinstance(groups_dict, dict) assert groups_dict["group_1"]["password"] == "from_group1" assert groups_dict["group_2"]["data"]["site"] == "site2" - def test_get_hosts_dict(self, inv): + def test_get_hosts_dict(self, inv: inventory.Inventory) -> None: hosts_dict = {n: h.dict() for n, h in inv.hosts.items()} dev1_groups = hosts_dict["dev1.group_1"]["groups"] dev2_paramiko_opts = hosts_dict["dev2.group_1"]["connection_options"]["paramiko"] @@ -598,7 +599,7 @@ def test_get_hosts_dict(self, inv): assert dev2_paramiko_opts["username"] == "root" assert "dev3.group_2" in hosts_dict - def test_add_group_to_host_runtime(self): + def test_add_group_to_host_runtime(self) -> None: orig_data = {"var1": "val1"} data = {"var3": "val3"} g1 = inventory.Group(name="g1", data=orig_data) @@ -618,7 +619,7 @@ def test_add_group_to_host_runtime(self): assert g3 in h1.groups assert h1.get("var3", None) == "val3" - def test_remove_group_from_host(self): + def test_remove_group_from_host(self) -> None: data = {"var3": "val3"} orig_data = {"var1": "val1"} g1 = inventory.Group(name="g1", data=orig_data) diff --git a/tests/core/test_processors.py b/tests/core/test_processors.py index eb27b663..1183a32f 100644 --- a/tests/core/test_processors.py +++ b/tests/core/test_processors.py @@ -23,7 +23,7 @@ def mock_subtask(task: Task) -> Result: class MockProcessor: - def __init__(self, data: Dict[str, None]) -> None: + def __init__(self, data: Dict[str, Dict[str, Any]]) -> None: self.data = data def task_started(self, task: Task) -> None: @@ -48,7 +48,7 @@ def _get_subtask_dict(self, task: Task, host: Host) -> Dict[str, Any]: parents.insert(0, parent.name) parent = parent.parent_task - data = self.data[parents[0]][host.name]["subtasks"] + data: Dict[str, Any] = self.data[parents[0]][host.name]["subtasks"] for p in parents[1:]: data = data[p]["subtasks"] return data diff --git a/tests/core/test_registered_plugins.py b/tests/core/test_registered_plugins.py index ab2996b6..a3e405a8 100644 --- a/tests/core/test_registered_plugins.py +++ b/tests/core/test_registered_plugins.py @@ -7,7 +7,7 @@ class Test: - def test_registered_runners(self): + def test_registered_runners(self) -> None: RunnersPluginRegister.deregister_all() RunnersPluginRegister.auto_register() assert RunnersPluginRegister.available == { @@ -15,7 +15,7 @@ def test_registered_runners(self): "serial": SerialRunner, } - def test_registered_inventory(self): + def test_registered_inventory(self) -> None: InventoryPluginRegister.deregister_all() InventoryPluginRegister.auto_register() assert InventoryPluginRegister.available == { diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index f78720a2..a96506b7 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -1,25 +1,27 @@ import logging +from typing import List, Optional +from nornir.core import Nornir from nornir.core.exceptions import NornirSubTaskError -from nornir.core.task import Result +from nornir.core.task import Result, Task class CustomException(Exception): pass -def a_task_for_testing(task, fail_on=None): +def a_task_for_testing(task: Task, fail_on: Optional[List[str]] = None) -> Result: fail_on = fail_on or [] if task.host.name in fail_on: raise CustomException() return Result(host=task.host, stdout=task.host.name) -def a_failed_task_for_testing(task): +def a_failed_task_for_testing(task: Task) -> Result: return Result(host=task.host, stdout=task.host.name, failed=True) -def a_failed_task_for_testing_overrides_severity(task): +def a_failed_task_for_testing_overrides_severity(task: Task) -> Result: return Result( host=task.host, stdout=task.host.name, @@ -28,18 +30,22 @@ def a_failed_task_for_testing_overrides_severity(task): ) -def a_task_to_test_dry_run(task, expected_dry_run_value, dry_run=None): +def a_task_to_test_dry_run( + task: Task, expected_dry_run_value: bool, dry_run: Optional[bool] = None +) -> None: assert task.is_dry_run(dry_run) is expected_dry_run_value -def sub_task_for_testing(task, fail_on=None): +def sub_task_for_testing(task: Task, fail_on: Optional[List[str]] = None) -> None: task.run( a_task_for_testing, fail_on=fail_on, ) -def sub_task_for_testing_overrides_severity(task, fail_on=None): +def sub_task_for_testing_overrides_severity( + task: Task, fail_on: Optional[List[str]] = None +) -> None: task.run( a_task_for_testing, fail_on=fail_on, @@ -47,34 +53,36 @@ def sub_task_for_testing_overrides_severity(task, fail_on=None): ) -def fail_command_subtask_no_capture(task, fail_on=None): +def fail_command_subtask_no_capture(task: Task, fail_on: Optional[List[str]] = None) -> str: task.run(a_task_for_testing, fail_on=fail_on) return "I shouldn't be here" -def fail_command_subtask_capture(task, fail_on=None): +def fail_command_subtask_capture(task: Task, fail_on: Optional[List[str]] = None) -> Optional[str]: try: task.run(a_task_for_testing, fail_on=fail_on) except Exception: return "I captured this succcessfully" + return None class Test: - def test_task(self, nornir): + def test_task(self, nornir: Nornir) -> None: result = nornir.run(a_task_for_testing) assert result for h, r in result.items(): assert r.stdout.strip() == h - def test_sub_task(self, nornir): + def test_sub_task(self, nornir: Nornir) -> None: result = nornir.run(sub_task_for_testing) assert result for h, r in result.items(): assert r[0].name == "sub_task_for_testing" assert r[1].name == "a_task_for_testing" + assert r[1].stdout is not None assert h == r[1].stdout.strip() - def test_skip_failed_host(self, nornir): + def test_skip_failed_host(self, nornir: Nornir) -> None: result = nornir.run(sub_task_for_testing, fail_on=["dev3.group_2"]) assert result.failed assert "dev3.group_2" in result @@ -84,13 +92,14 @@ def test_skip_failed_host(self, nornir): assert r.failed else: assert not r.failed + assert r[1].stdout is not None assert h == r[1].stdout.strip() result = nornir.run(a_task_for_testing) assert not result.failed assert "dev3.group_2" not in result - def test_run_on(self, nornir): + def test_run_on(self, nornir: Nornir) -> None: result = nornir.run(a_task_for_testing, fail_on=["dev3.group_2"]) assert result.failed assert "dev3.group_2" in result @@ -115,7 +124,7 @@ def test_run_on(self, nornir): assert "dev3.group_2" not in result assert "dev1.group_1" in result - def test_severity(self, nornir): + def test_severity(self, nornir: Nornir) -> None: r = nornir.run(a_task_for_testing) for host, result in r.items(): assert result[0].severity_level == logging.INFO @@ -159,7 +168,7 @@ def test_severity(self, nornir): # Reset all failed host for next test nornir.data.reset_failed_hosts() - def test_dry_run(self, nornir): + def test_dry_run(self, nornir: Nornir) -> None: host = nornir.filter(name="dev3.group_2") r = host.run(a_task_to_test_dry_run, expected_dry_run_value=True) assert not r["dev3.group_2"].failed @@ -175,14 +184,14 @@ def test_dry_run(self, nornir): r = host.run(a_task_to_test_dry_run, expected_dry_run_value=False) assert r["dev3.group_2"].failed - def test_subtask_exception_no_capture(self, nornir): + def test_subtask_exception_no_capture(self, nornir: Nornir) -> None: host = nornir.filter(name="dev1.group_1") r = host.run(task=fail_command_subtask_no_capture, fail_on=["dev1.group_1"]) assert r.failed assert r["dev1.group_1"][0].exception.__class__ is NornirSubTaskError assert r["dev1.group_1"][1].exception.__class__ is CustomException - def test_subtask_exception_capture(self, nornir): + def test_subtask_exception_capture(self, nornir: Nornir) -> None: host = nornir.filter(name="dev1.group_1") r = host.run(task=fail_command_subtask_capture, fail_on=["dev1.group_1"]) assert r.failed diff --git a/tests/plugins/inventory/test_simple_inventory.py b/tests/plugins/inventory/test_simple_inventory.py index 8eae3d69..18bcddb5 100644 --- a/tests/plugins/inventory/test_simple_inventory.py +++ b/tests/plugins/inventory/test_simple_inventory.py @@ -6,7 +6,7 @@ class Test: - def test(self): + def test(self) -> None: host_file = f"{dir_path}/data/hosts.yaml" group_file = f"{dir_path}/data/groups.yaml" defaults_file = f"{dir_path}/data/defaults.yaml" @@ -232,7 +232,7 @@ def test(self): }, } - def test_simple_inventory_empty(self): + def test_simple_inventory_empty(self) -> None: """Verify completely empty groups.yaml and defaults.yaml doesn't generate exception.""" host_file = f"{dir_path}/data/hosts-nogroups.yaml" group_file = f"{dir_path}/data/groups-empty.yaml" diff --git a/tests/plugins/processors/test_serial.py b/tests/plugins/processors/test_serial.py index 47199d64..1cd6d364 100644 --- a/tests/plugins/processors/test_serial.py +++ b/tests/plugins/processors/test_serial.py @@ -1,6 +1,8 @@ import datetime import time +from nornir.core import Nornir +from nornir.core.task import Task from nornir.plugins.runners import SerialRunner @@ -8,32 +10,32 @@ class CustomException(Exception): pass -def a_task_for_testing(task, command): +def a_task_for_testing(task: Task, command: str) -> None: if command == "failme": raise CustomException() -def blocking_task(task, wait): +def blocking_task(task: Task, wait: float) -> None: time.sleep(wait) -def failing_task_simple(task): +def failing_task_simple(task: Task) -> None: raise Exception(task.host.name) -def failing_task_complex(task): +def failing_task_complex(task: Task) -> None: a_task_for_testing(task, command="failme") class TestSerialRunner: - def test_blocking_task_single_thread(self, nornir): + def test_blocking_task_single_thread(self, nornir: Nornir) -> None: t1 = datetime.datetime.now() nornir.with_runner(SerialRunner()).run(blocking_task, wait=0.5) t2 = datetime.datetime.now() delta = t2 - t1 assert delta.seconds == 3, delta - def test_failing_task_simple_singlethread(self, nornir): + def test_failing_task_simple_singlethread(self, nornir: Nornir) -> None: result = nornir.with_runner(SerialRunner()).run(failing_task_simple) processed = False for k, v in result.items(): @@ -42,7 +44,7 @@ def test_failing_task_simple_singlethread(self, nornir): assert isinstance(v.exception, Exception), v assert processed - def test_failing_task_complex_singlethread(self, nornir): + def test_failing_task_complex_singlethread(self, nornir: Nornir) -> None: result = nornir.with_runner(SerialRunner()).run(failing_task_complex) processed = False for k, v in result.items(): diff --git a/tests/plugins/processors/test_threaded.py b/tests/plugins/processors/test_threaded.py index b0974c6d..75849bb1 100644 --- a/tests/plugins/processors/test_threaded.py +++ b/tests/plugins/processors/test_threaded.py @@ -3,7 +3,9 @@ import pytest +from nornir.core import Nornir from nornir.core.exceptions import NornirExecutionError +from nornir.core.task import Task from nornir.plugins.runners import ThreadedRunner NUM_WORKERS = 20 @@ -13,40 +15,40 @@ class CustomException(Exception): pass -def a_task_for_testing(task, command): +def a_task_for_testing(task: Task, command: str) -> None: if command == "failme": raise CustomException() -def blocking_task(task, wait): +def blocking_task(task: Task, wait: float) -> None: time.sleep(wait) -def failing_task_simple(task): +def failing_task_simple(task: Task) -> None: raise Exception(task.host.name) -def failing_task_complex(task): +def failing_task_complex(task: Task) -> None: a_task_for_testing(task, command="failme") -def change_data(task): +def change_data(task: Task) -> None: task.host["my_changed_var"] = task.host.name -def verify_data_change(task): +def verify_data_change(task: Task) -> None: assert task.host["my_changed_var"] == task.host.name class Test: - def test_blocking_task_multithreading(self, nornir): + def test_blocking_task_multithreading(self, nornir: Nornir) -> None: t1 = datetime.datetime.now() nornir.with_runner(ThreadedRunner(num_workers=NUM_WORKERS)).run(blocking_task, wait=2) t2 = datetime.datetime.now() delta = t2 - t1 assert delta.seconds == 2, delta - def test_failing_task_simple_multithread(self, nornir): + def test_failing_task_simple_multithread(self, nornir: Nornir) -> None: result = nornir.with_runner(ThreadedRunner(num_workers=NUM_WORKERS)).run( failing_task_simple, ) @@ -57,7 +59,7 @@ def test_failing_task_simple_multithread(self, nornir): assert isinstance(v.exception, Exception), v assert processed - def test_failing_task_complex_multithread(self, nornir): + def test_failing_task_complex_multithread(self, nornir: Nornir) -> None: result = nornir.with_runner(ThreadedRunner(num_workers=NUM_WORKERS)).run( failing_task_complex, ) @@ -68,7 +70,7 @@ def test_failing_task_complex_multithread(self, nornir): assert isinstance(v.exception, CustomException), v assert processed - def test_failing_task_complex_multithread_raise_on_error(self, nornir): + def test_failing_task_complex_multithread_raise_on_error(self, nornir: Nornir) -> None: with pytest.raises(NornirExecutionError) as e: nornir.with_runner(ThreadedRunner(num_workers=NUM_WORKERS)).run( failing_task_complex, raise_on_error=True @@ -77,7 +79,7 @@ def test_failing_task_complex_multithread_raise_on_error(self, nornir): assert isinstance(k, str), k assert isinstance(v.exception, CustomException), v - def test_change_data_in_thread(self, nornir): + def test_change_data_in_thread(self, nornir: Nornir) -> None: nornir.with_runner(ThreadedRunner(num_workers=NUM_WORKERS)).run( change_data, ) diff --git a/tests/wrapper.py b/tests/wrapper.py index d23a85ad..40acbe01 100644 --- a/tests/wrapper.py +++ b/tests/wrapper.py @@ -1,10 +1,11 @@ import sys from io import StringIO +from typing import Any, Callable, Dict from decorator import decorator -def wrap_cli_test(output, save_output=False): +def wrap_cli_test(output: str, save_output: bool = False) -> Callable[[Callable[..., Any]], None]: """ This decorator captures the stdout and stder and compare it with the contects of the specified files. @@ -15,7 +16,7 @@ def wrap_cli_test(output, save_output=False): """ @decorator - def run_test(func, *args, **kwargs) -> None: + def run_test(func: Callable[..., Any], *args: Any, **kwargs: Dict[str, Any]) -> Any: stdout = StringIO() backup_stdout = sys.stdout sys.stdout = stdout