Skip to content
This repository has been archived by the owner on Nov 19, 2023. It is now read-only.

Commit

Permalink
refactor: Improve type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
sondrelg committed Dec 29, 2021
1 parent 8681c40 commit a0926f5
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 26 deletions.
2 changes: 1 addition & 1 deletion manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys


def main():
def main() -> None:
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "test_project.settings")
try:
from django.core.management import execute_from_command_line
Expand Down
16 changes: 9 additions & 7 deletions openapi_tester/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from rest_framework.views import APIView

if TYPE_CHECKING:
from typing import Callable
from typing import Any, Callable
from urllib.parse import ParseResult

from django.urls import ResolverMatch
Expand Down Expand Up @@ -74,7 +74,7 @@ def get_schema(self) -> dict:
return self.get_schema()

def de_reference_schema(self, schema: dict) -> dict:
url = schema["basePath"] if "basePath" in schema else self.base_path
url = schema.get("basePath", self.base_path)
recursion_handler = handle_recursion_limit(schema)
resolver = RefResolver(
schema,
Expand Down Expand Up @@ -138,7 +138,7 @@ def resolve_path(self, endpoint_path: str, method: str) -> tuple[str, ResolverMa
for key, value in reversed(list(resolved_route.kwargs.items())):
index = path.rfind(str(value))
path = f"{path[:index]}{{{key}}}{path[index + len(str(value)):]}"
if "{pk}" in path and api_settings.SCHEMA_COERCE_PATH_PK:
if "{pk}" in path and api_settings.SCHEMA_COERCE_PATH_PK: # noqa: FS003
path, resolved_route = self.handle_pk_parameter(
resolved_route=resolved_route, path=path, method=method
)
Expand Down Expand Up @@ -182,7 +182,7 @@ def load_schema(self) -> dict:
Loads generated schema from drf-yasg and returns it as a dict.
"""
odict_schema = self.schema_generator.get_schema(None, True)
return loads(dumps(odict_schema.as_odict()))
return cast(dict, loads(dumps(odict_schema.as_odict())))

def resolve_path(self, endpoint_path: str, method: str) -> tuple[str, ResolverMatch]:
de_parameterized_path, resolved_path = super().resolve_path(endpoint_path=endpoint_path, method=method)
Expand All @@ -206,7 +206,7 @@ def load_schema(self) -> dict:
"""
Loads generated schema from drf_spectacular and returns it as a dict.
"""
return loads(dumps(self.schema_generator.get_schema(public=True)))
return cast(dict, loads(dumps(self.schema_generator.get_schema(public=True))))

def resolve_path(self, endpoint_path: str, method: str) -> tuple[str, ResolverMatch]:
from drf_spectacular.settings import spectacular_settings
Expand All @@ -227,7 +227,7 @@ def __init__(self, path: str, field_key_map: dict[str, str] | None = None):
super().__init__(field_key_map=field_key_map)
self.path = path if not isinstance(path, pathlib.PosixPath) else str(path)

def load_schema(self) -> dict:
def load_schema(self) -> dict[str, Any]:
"""
Loads a static OpenAPI schema from file, and parses it to a python dict.
Expand All @@ -236,4 +236,6 @@ def load_schema(self) -> dict:
"""
with open(self.path, encoding="utf-8") as file:
content = file.read()
return json.loads(content) if ".json" in self.path else yaml.load(content, Loader=yaml.FullLoader)
return cast(
dict, json.loads(content) if ".json" in self.path else yaml.load(content, Loader=yaml.FullLoader)
)
33 changes: 18 additions & 15 deletions openapi_tester/schema_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

from itertools import chain
from typing import TYPE_CHECKING, Callable, List, cast
from typing import TYPE_CHECKING, Any, Callable, List, Optional, cast

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
Expand Down Expand Up @@ -38,7 +38,6 @@
)

if TYPE_CHECKING:
from typing import Any

from rest_framework.response import Response

Expand Down Expand Up @@ -79,7 +78,7 @@ def __init__(
raise ImproperlyConfigured(INIT_ERROR)

@staticmethod
def get_key_value(schema: dict, key: str, error_addon: str = "") -> dict:
def get_key_value(schema: dict[str, dict], key: str, error_addon: str = "") -> dict:
"""
Returns the value of a given key
"""
Expand All @@ -91,7 +90,7 @@ def get_key_value(schema: dict, key: str, error_addon: str = "") -> dict:
) from e

@staticmethod
def get_status_code(schema: dict, status_code: str | int, error_addon: str = "") -> dict:
def get_status_code(schema: dict[str | int, dict], status_code: str | int, error_addon: str = "") -> dict:
"""
Returns the status code section of a schema, handles both str and int status codes
"""
Expand All @@ -104,7 +103,7 @@ def get_status_code(schema: dict, status_code: str | int, error_addon: str = "")
)

@staticmethod
def get_schema_type(schema: dict) -> str | None:
def get_schema_type(schema: dict[str, str]) -> str | None:
if "type" in schema:
return schema["type"]
if "properties" in schema or "additionalProperties" in schema:
Expand Down Expand Up @@ -132,14 +131,16 @@ def get_response_schema_section(self, response: Response) -> dict[str, Any]:
method_object = self.get_key_value(
route_object,
response_method,
f"\n\nUndocumented method: {response_method}.\n\nDocumented methods: {[method.lower() for method in route_object.keys() if method.lower() != 'parameters']}.",
f"\n\nUndocumented method: {response_method}.\n\nDocumented methods: "
f"{[method.lower() for method in route_object.keys() if method.lower() != 'parameters']}.",
)

responses_object = self.get_key_value(method_object, "responses")
status_code_object = self.get_status_code(
responses_object,
response.status_code,
f"\n\nUndocumented status code: {response.status_code}.\n\nDocumented status codes: {list(responses_object.keys())}. ",
f"\n\nUndocumented status code: {response.status_code}.\n\n"
f"Documented status codes: {list(responses_object.keys())}. ",
)

if "openapi" not in schema: # pylint: disable=E1135
Expand All @@ -155,20 +156,22 @@ def get_response_schema_section(self, response: Response) -> dict[str, Any]:
json_object = self.get_key_value(
content_object,
"application/json",
f"\n\nNo `application/json` responses documented for method: {response_method}, path: {parameterized_path}",
f"\n\nNo `application/json` responses documented for method: "
f"{response_method}, path: {parameterized_path}",
)
return self.get_key_value(json_object, "schema")

if response.json():
raise UndocumentedSchemaSectionError(
UNDOCUMENTED_SCHEMA_SECTION_ERROR.format(
key="content",
error_addon=f"\n\nNo `content` defined for this response: {response_method}, path: {parameterized_path}",
error_addon=f"\n\nNo `content` defined for this response: "
f"{response_method}, path: {parameterized_path}",
)
)
return {}

def handle_one_of(self, schema_section: dict, data: Any, reference: str, **kwargs: Any):
def handle_one_of(self, schema_section: dict, data: Any, reference: str, **kwargs: Any) -> None:
matches = 0
passed_schema_section_formats = set()
for option in schema_section["oneOf"]:
Expand All @@ -186,7 +189,7 @@ def handle_one_of(self, schema_section: dict, data: Any, reference: str, **kwarg
if matches != 1:
raise DocumentationError(f"{VALIDATE_ONE_OF_ERROR.format(matches=matches)}\n\nReference: {reference}.oneOf")

def handle_any_of(self, schema_section: dict, data: Any, reference: str, **kwargs: Any):
def handle_any_of(self, schema_section: dict, data: Any, reference: str, **kwargs: Any) -> None:
any_of: list[dict[str, Any]] = schema_section.get("anyOf", [])
for schema in chain(any_of, lazy_combinations(any_of)):
try:
Expand Down Expand Up @@ -257,7 +260,7 @@ def test_schema_section(
if not schema_section_type:
return
combined_validators = cast(
List[Callable],
List[Callable[[dict, Any], Optional[str]]],
[
validate_type,
validate_format,
Expand Down Expand Up @@ -349,7 +352,7 @@ def test_openapi_object(
ignore_case=ignore_case,
)

def test_openapi_array(self, schema_section: dict, data: dict, reference: str, **kwargs: Any) -> None:
def test_openapi_array(self, schema_section: dict[str, Any], data: dict, reference: str, **kwargs: Any) -> None:
for datum in data:
self.test_schema_section(
# the items keyword is required in arrays
Expand All @@ -364,8 +367,8 @@ def validate_response(
response: Response,
case_tester: Callable[[str], None] | None = None,
ignore_case: list[str] | None = None,
validators: list[Callable[[dict, Any], str | None]] | None = None,
):
validators: list[Callable[[dict[str, Any], Any], str | None]] | None = None,
) -> None:
"""
Verifies that an OpenAPI schema definition matches an API response.
Expand Down
4 changes: 2 additions & 2 deletions openapi_tester/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from typing import Any, Callable


def create_validator(validation_fn: Callable, wrap_as_validator: bool = False) -> Callable:
def wrapped(value: Any):
def create_validator(validation_fn: Callable, wrap_as_validator: bool = False) -> Callable[[Any], bool]:
def wrapped(value: Any) -> bool:
try:
return bool(validation_fn(value)) or not wrap_as_validator
except (ValueError, ValidationError):
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ per-file-ignores =
test_project/*:FS003

[mypy]
python_version = 3.10
show_column_numbers = True
show_error_context = False
ignore_missing_imports = True
Expand Down

0 comments on commit a0926f5

Please sign in to comment.