From e0dfa6c0c3add080e22617d88811be0834c2b4cc Mon Sep 17 00:00:00 2001 From: Kirk Byers Date: Mon, 16 Dec 2024 10:32:05 -0800 Subject: [PATCH] Use better for Generic type enforcement and mypy --- nornir/core/configuration.py | 38 +++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/nornir/core/configuration.py b/nornir/core/configuration.py index 543d2bc1..29548f14 100644 --- a/nornir/core/configuration.py +++ b/nornir/core/configuration.py @@ -5,7 +5,7 @@ import sys import warnings from pathlib import Path -from typing import Any, Dict, List, Optional, Type, TypeVar +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar import ruamel.yaml @@ -16,7 +16,7 @@ T = TypeVar("T") -class Parameter: +class Parameter(Generic[T]): def __init__( self, envvar: str, @@ -47,6 +47,10 @@ def resolve(self, value: Optional[T]) -> T: if v is None: v = self.default + + if not isinstance(v, self.type): + raise TypeError(f"Expected type {self.type}, got {type(v)}") + return v @@ -54,7 +58,7 @@ class SSHConfig: __slots__ = ("config_file",) class Parameters: - config_file = Parameter(default=DEFAULT_SSH_CONFIG, envvar="NORNIR_SSH_CONFIG_FILE") + config_file = Parameter[str](default=DEFAULT_SSH_CONFIG, envvar="NORNIR_SSH_CONFIG_FILE") def __init__(self, config_file: Optional[str] = None) -> None: self.config_file = self.Parameters.config_file.resolve(config_file) @@ -67,10 +71,12 @@ class InventoryConfig: __slots__ = "options", "plugin", "transform_function", "transform_function_options" class Parameters: - plugin = Parameter(typ=str, default="SimpleInventory", envvar="NORNIR_INVENTORY_PLUGIN") - options = Parameter(default={}, envvar="NORNIR_INVENTORY_OPTIONS") - transform_function = Parameter(typ=str, envvar="NORNIR_INVENTORY_TRANSFORM_FUNCTION") - transform_function_options = Parameter( + plugin = Parameter[str]( + typ=str, default="SimpleInventory", envvar="NORNIR_INVENTORY_PLUGIN" + ) + options = Parameter[Dict[str, Any]](default={}, envvar="NORNIR_INVENTORY_OPTIONS") + transform_function = Parameter[str](typ=str, envvar="NORNIR_INVENTORY_TRANSFORM_FUNCTION") + transform_function_options = Parameter[Dict[str, Any]]( default={}, envvar="NORNIR_INVENTORY_TRANSFORM_FUNCTION_OPTIONS" ) @@ -101,15 +107,15 @@ class LoggingConfig: __slots__ = "enabled", "format", "level", "log_file", "loggers", "to_console" class Parameters: - enabled = Parameter(default=True, envvar="NORNIR_LOGGING_ENABLED") - level = Parameter(default="INFO", envvar="NORNIR_LOGGING_LEVEL") - log_file = Parameter(default="nornir.log", envvar="NORNIR_LOGGING_LOG_FILE") - format = Parameter( + enabled = Parameter[bool](default=True, envvar="NORNIR_LOGGING_ENABLED") + level = Parameter[str](default="INFO", envvar="NORNIR_LOGGING_LEVEL") + log_file = Parameter[str](default="nornir.log", envvar="NORNIR_LOGGING_LOG_FILE") + format = Parameter[str]( default="%(asctime)s - %(name)12s - %(levelname)8s - %(funcName)10s() - %(message)s", envvar="NORNIR_LOGGING_FORMAT", ) - to_console = Parameter(default=False, envvar="NORNIR_LOGGING_TO_CONSOLE") - loggers = Parameter(default=["nornir"], envvar="NORNIR_LOGGING_LOGGERS") + to_console = Parameter[bool](default=False, envvar="NORNIR_LOGGING_TO_CONSOLE") + loggers = Parameter[List[str]](default=["nornir"], envvar="NORNIR_LOGGING_LOGGERS") def __init__( self, @@ -194,8 +200,8 @@ class RunnerConfig: __slots__ = ("options", "plugin") class Parameters: - plugin = Parameter(default="threaded", envvar="NORNIR_RUNNER_PLUGIN") - options = Parameter(default={}, envvar="NORNIR_RUNNER_OPTIONS") + plugin = Parameter[str](default="threaded", envvar="NORNIR_RUNNER_PLUGIN") + options = Parameter[Dict[str, Any]](default={}, envvar="NORNIR_RUNNER_OPTIONS") def __init__( self, plugin: Optional[str] = None, options: Optional[Dict[str, Any]] = None @@ -214,7 +220,7 @@ class CoreConfig: __slots__ = ("raise_on_error",) class Parameters: - raise_on_error = Parameter(default=False, envvar="NORNIR_CORE_RAISE_ON_ERROR") + raise_on_error = Parameter[bool](default=False, envvar="NORNIR_CORE_RAISE_ON_ERROR") def __init__(self, raise_on_error: Optional[bool] = None) -> None: self.raise_on_error = self.Parameters.raise_on_error.resolve(raise_on_error)