From 19ac3a73ebe651f6fd0be4c67b8160c67e40012f Mon Sep 17 00:00:00 2001 From: Patrick Ogenstad Date: Thu, 19 Sep 2024 18:47:03 +0200 Subject: [PATCH] Add annotations for returned data Start to add type annotations within the tests --- pyproject.toml | 2 -- tests/conftest.py | 12 ++++++++---- tests/core/test_InitNornir.py | 4 ++-- tests/core/test_connections.py | 2 +- tests/core/test_inventory.py | 2 +- tests/wrapper.py | 2 +- 6 files changed, 13 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6f0c2496..5ba68496 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -236,8 +236,6 @@ max-returns = 11 "ANN002", # Missing type annotation for `*args` "ANN003", # Missing type annotation for `**kwargs` "ANN201", # Missing return type annotation for public function - "ANN202", # Missing return type annotation for private function - "ANN206", # Missing return type annotation for classmethod "ARG001", # Unused function argument "B007", # Loop control variable `host` not used within loop body "C414", # Unnecessary `list` call within `sorted()` diff --git a/tests/conftest.py b/tests/conftest.py index f3f389b9..a4648eca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ import os -from typing import List +from typing import Any, Dict, List, Type, TypeVar, Union import pytest import ruamel.yaml @@ -18,6 +18,8 @@ from nornir.core.state import GlobalState from nornir.core.task import AggregatedResult, Task +ElementType = TypeVar("ElementType", bound=Union[Group, Host]) + global_data = GlobalState(dry_run=True) @@ -25,7 +27,7 @@ def inventory_from_yaml(): dir_path = os.path.dirname(os.path.realpath(__file__)) yml = ruamel.yaml.YAML(typ="safe") - def get_connection_options(data): + def get_connection_options(data) -> Dict[str, ConnectionOptions]: cp = {} for cn, c in data.items(): cp[cn] = ConnectionOptions( @@ -38,7 +40,7 @@ def get_connection_options(data): ) return cp - def get_defaults(): + def get_defaults() -> Defaults: defaults_file = f"{dir_path}/inventory_data/defaults.yaml" with open(defaults_file, "r") as f: defaults_dict = yml.load(f) @@ -55,7 +57,9 @@ def get_defaults(): ), ) - def get_inventory_element(typ, data, name, defaults): + def get_inventory_element( + typ: Type[ElementType], data: dict[str, Any], name: str, defaults: Union[Defaults, None] + ) -> ElementType: return typ( name=name, hostname=data.get("hostname"), diff --git a/tests/core/test_InitNornir.py b/tests/core/test_InitNornir.py index 745eab09..6d863546 100644 --- a/tests/core/test_InitNornir.py +++ b/tests/core/test_InitNornir.py @@ -167,7 +167,7 @@ def test_InitNornir_different_transform_function_by_string_with_bad_options(self class TestLogging: @classmethod - def cleanup(cls): + def cleanup(cls) -> None: # this does not work as setup_method, because pytest injects # _pytest.logging.LogCaptureHandler handler to the root logger # and StreamHandler to _pytest.capture.EncodedFile to other loggers @@ -183,7 +183,7 @@ def cleanup(cls): logger_.setLevel(logging.NOTSET) @classmethod - def teardown_class(cls): + def teardown_class(cls) -> None: cls.cleanup() def test_InitNornir_logging_defaults(self): diff --git a/tests/core/test_connections.py b/tests/core/test_connections.py index 1f773026..9194b61d 100644 --- a/tests/core/test_connections.py +++ b/tests/core/test_connections.py @@ -106,7 +106,7 @@ def validate_params(task, conn, params, nornir_config): class Test: @classmethod - def setup_class(cls): + def setup_class(cls) -> None: ConnectionPluginRegister.deregister_all() ConnectionPluginRegister.register("dummy", DummyConnectionPlugin) ConnectionPluginRegister.register("dummy2", DummyConnectionPlugin) diff --git a/tests/core/test_inventory.py b/tests/core/test_inventory.py index 6219b6e5..74ddf66c 100644 --- a/tests/core/test_inventory.py +++ b/tests/core/test_inventory.py @@ -397,7 +397,7 @@ def test_filtering_func(self, inv): ) assert long_names == ["dev1.group_1", "dev4.group_2", "dev6.group_3"] - def longer_than(dev, length): + def longer_than(dev, length) -> bool: return len(dev["my_var"]) > length long_names = sorted(list(inv.filter(filter_func=longer_than, length=20).hosts.keys())) diff --git a/tests/wrapper.py b/tests/wrapper.py index 84beb401..d23a85ad 100644 --- a/tests/wrapper.py +++ b/tests/wrapper.py @@ -15,7 +15,7 @@ def wrap_cli_test(output, save_output=False): """ @decorator - def run_test(func, *args, **kwargs): + def run_test(func, *args, **kwargs) -> None: stdout = StringIO() backup_stdout = sys.stdout sys.stdout = stdout