Skip to content

Commit

Permalink
improve mypy typing #600
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Dec 10, 2023
1 parent 4ef322c commit bbf1bc1
Show file tree
Hide file tree
Showing 14 changed files with 218 additions and 146 deletions.
3 changes: 2 additions & 1 deletion drf_spectacular/contrib/rest_polymorphic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from drf_spectacular.drainage import warn
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import (
ResolvedComponent, build_basic_type, build_object_type, is_patched_serializer, warn,
ResolvedComponent, build_basic_type, build_object_type, is_patched_serializer,
)
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes
Expand Down
46 changes: 30 additions & 16 deletions drf_spectacular/drainage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,21 @@
import inspect
import sys
from collections import defaultdict
from typing import DefaultDict, List, Optional, Tuple
from typing import Any, Callable, DefaultDict, List, Optional, Tuple, TypeVar

if sys.version_info >= (3, 8):
from typing import ( # type: ignore[attr-defined] # noqa: F401
Final, Literal, TypedDict, _TypedDictMeta,
)
else:
from typing_extensions import Final, Literal, TypedDict, _TypedDictMeta # noqa: F401

if sys.version_info >= (3, 10):
from typing import TypeGuard # noqa: F401
else:
from typing_extensions import TypeGuard # noqa: F401

F = TypeVar('F', bound=Callable[..., Any])


class GeneratorStats:
Expand Down Expand Up @@ -37,20 +51,20 @@ def silence(self):
finally:
self.silent = tmp

def reset(self):
def reset(self) -> None:
self._warn_cache.clear()
self._error_cache.clear()

def enable_color(self):
def enable_color(self) -> None:
self._blue = '\033[0;34m'
self._red = '\033[0;31m'
self._yellow = '\033[0;33m'
self._clear = '\033[0m'

def enable_trace_lineno(self):
def enable_trace_lineno(self) -> None:
self._trace_lineno = True

def _get_current_trace(self):
def _get_current_trace(self) -> Tuple[Optional[str], str]:
source_locations = [t for t in self._traces if t[0]]
if source_locations:
sourcefile, lineno, _ = source_locations[-1]
Expand All @@ -60,7 +74,7 @@ def _get_current_trace(self):
breadcrumbs = ' > '.join(t[2] for t in self._traces)
return source_location, breadcrumbs

def emit(self, msg, severity):
def emit(self, msg: str, severity: str) -> None:
assert severity in ['warning', 'error']
cache = self._warn_cache if severity == 'warning' else self._error_cache

Expand All @@ -75,7 +89,7 @@ def emit(self, msg, severity):
print(msg, file=sys.stderr)
cache[msg] += 1

def emit_summary(self):
def emit_summary(self) -> None:
if not self.silent and (self._warn_cache or self._error_cache):
print(
f'\nSchema generation summary:\n'
Expand All @@ -88,7 +102,7 @@ def emit_summary(self):
GENERATOR_STATS = GeneratorStats()


def warn(msg, delayed=None):
def warn(msg: str, delayed: Any = None) -> None:
if delayed:
warnings = get_override(delayed, 'warnings', [])
warnings.append(msg)
Expand All @@ -97,7 +111,7 @@ def warn(msg, delayed=None):
GENERATOR_STATS.emit(msg, 'warning')


def error(msg, delayed=None):
def error(msg: str, delayed: Any = None) -> None:
if delayed:
errors = get_override(delayed, 'errors', [])
errors.append(msg)
Expand All @@ -106,7 +120,7 @@ def error(msg, delayed=None):
GENERATOR_STATS.emit(msg, 'error')


def reset_generator_stats():
def reset_generator_stats() -> None:
GENERATOR_STATS.reset()


Expand Down Expand Up @@ -136,7 +150,7 @@ def _get_source_location(obj):
return sourcefile, lineno


def has_override(obj, prop):
def has_override(obj: Any, prop: str) -> bool:
if isinstance(obj, functools.partial):
obj = obj.func
if not hasattr(obj, '_spectacular_annotation'):
Expand All @@ -146,15 +160,15 @@ def has_override(obj, prop):
return True


def get_override(obj, prop, default=None):
def get_override(obj: Any, prop: str, default: Any = None) -> Any:
if isinstance(obj, functools.partial):
obj = obj.func
if not has_override(obj, prop):
return default
return obj._spectacular_annotation[prop]


def set_override(obj, prop, value):
def set_override(obj: Any, prop: str, value: Any) -> Any:
if not hasattr(obj, '_spectacular_annotation'):
obj._spectacular_annotation = {}
elif '_spectacular_annotation' not in obj.__dict__:
Expand All @@ -163,7 +177,7 @@ def set_override(obj, prop, value):
return obj


def get_view_method_names(view, schema=None):
def get_view_method_names(view, schema=None) -> List[str]:
schema = schema or view.schema
return [
item for item in dir(view) if callable(getattr(view, item)) and (
Expand Down Expand Up @@ -201,6 +215,6 @@ def wrapped_method(self, request, *args, **kwargs):
return wrapped_method


def cache(user_function):
def cache(user_function: F) -> F:
""" simple polyfill for python < 3.9 """
return functools.lru_cache(maxsize=None)(user_function)
return functools.lru_cache(maxsize=None)(user_function) # type: ignore
21 changes: 12 additions & 9 deletions drf_spectacular/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from drf_spectacular.openapi import AutoSchema


_SchemaType = Dict[str, Any]


class OpenApiAuthenticationExtension(OpenApiGeneratorExtension['OpenApiAuthenticationExtension']):
"""
Extension for specifying authentication schemes.
Expand All @@ -29,7 +32,7 @@ class OpenApiAuthenticationExtension(OpenApiGeneratorExtension['OpenApiAuthentic
``get_security_definition()`` is expected to return a valid `OpenAPI security scheme object
<https://spec.openapis.org/oas/v3.0.3#securitySchemeObject>`_
"""
_registry: List['OpenApiAuthenticationExtension'] = []
_registry: List[Type['OpenApiAuthenticationExtension']] = []

name: Union[str, List[str]]

Expand All @@ -43,7 +46,7 @@ def get_security_requirement(
return {name: [] for name in self.name}

@abstractmethod
def get_security_definition(self, auto_schema: 'AutoSchema') -> Union[dict, List[dict]]:
def get_security_definition(self, auto_schema: 'AutoSchema') -> Union[_SchemaType, List[_SchemaType]]:
pass # pragma: no cover


Expand All @@ -59,13 +62,13 @@ class OpenApiSerializerExtension(OpenApiGeneratorExtension['OpenApiSerializerExt
``map_serializer()`` is expected to return a valid `OpenAPI schema object
<https://spec.openapis.org/oas/v3.0.3#schemaObject>`_.
"""
_registry: List['OpenApiSerializerExtension'] = []
_registry: List[Type['OpenApiSerializerExtension']] = []

def get_name(self, auto_schema: 'AutoSchema', direction: Direction) -> Optional[str]:
""" return str for overriding default name extraction """
return None

def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction):
def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction) -> _SchemaType:
""" override for customized serializer mapping """
return auto_schema._map_serializer(self.target_class, direction, bypass_extensions=True)

Expand All @@ -82,14 +85,14 @@ class OpenApiSerializerFieldExtension(OpenApiGeneratorExtension['OpenApiSerializ
``map_serializer_field()`` is expected to return a valid `OpenAPI schema object
<https://spec.openapis.org/oas/v3.0.3#schemaObject>`_.
"""
_registry: List['OpenApiSerializerFieldExtension'] = []
_registry: List[Type['OpenApiSerializerFieldExtension']] = []

def get_name(self) -> Optional[str]:
""" return str for breaking out field schema into separate named component """
return None

@abstractmethod
def map_serializer_field(self, auto_schema: 'AutoSchema', direction: Direction):
def map_serializer_field(self, auto_schema: 'AutoSchema', direction: Direction) -> _SchemaType:
""" override for customized serializer field mapping """
pass # pragma: no cover

Expand All @@ -102,7 +105,7 @@ class OpenApiViewExtension(OpenApiGeneratorExtension['OpenApiViewExtension']):
``ViewSet`` et al.). The discovered original view instance can be accessed with
``self.target`` and be subclassed if desired.
"""
_registry: List['OpenApiViewExtension'] = []
_registry: List[Type['OpenApiViewExtension']] = []

@classmethod
def _load_class(cls):
Expand All @@ -129,8 +132,8 @@ class OpenApiFilterExtension(OpenApiGeneratorExtension['OpenApiFilterExtension']
Using ``drf_spectacular.plumbing.build_parameter_type`` is recommended to generate
the appropriate raw dict objects.
"""
_registry: List['OpenApiFilterExtension'] = []
_registry: List[Type['OpenApiFilterExtension']] = []

@abstractmethod
def get_schema_operation_parameters(self, auto_schema: 'AutoSchema', *args, **kwargs) -> List[dict]:
def get_schema_operation_parameters(self, auto_schema: 'AutoSchema', *args, **kwargs) -> List[_SchemaType]:
pass # pragma: no cover
12 changes: 7 additions & 5 deletions drf_spectacular/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@

from django.urls import URLPattern, URLResolver
from rest_framework import views, viewsets
from rest_framework.schemas.generators import BaseSchemaGenerator # type: ignore
from rest_framework.schemas.generators import BaseSchemaGenerator
from rest_framework.schemas.generators import EndpointEnumerator as BaseEndpointEnumerator
from rest_framework.settings import api_settings

from drf_spectacular.drainage import add_trace_message, get_override, reset_generator_stats
from drf_spectacular.drainage import (
add_trace_message, error, get_override, reset_generator_stats, warn,
)
from drf_spectacular.extensions import OpenApiViewExtension
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.plumbing import (
ComponentRegistry, alpha_operation_sorter, build_root_object, camelize_operation, error,
get_class, is_versioning_supported, modify_for_versioning, normalize_result_object,
operation_matches_version, sanitize_result_object, warn,
ComponentRegistry, alpha_operation_sorter, build_root_object, camelize_operation, get_class,
is_versioning_supported, modify_for_versioning, normalize_result_object,
operation_matches_version, sanitize_result_object,
)
from drf_spectacular.settings import spectacular_settings

Expand Down
3 changes: 2 additions & 1 deletion drf_spectacular/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from inflection import camelize
from rest_framework.settings import api_settings

from drf_spectacular.drainage import warn
from drf_spectacular.plumbing import (
ResolvedComponent, list_hash, load_enum_name_overrides, safe_ref, warn,
ResolvedComponent, list_hash, load_enum_name_overrides, safe_ref,
)
from drf_spectacular.settings import spectacular_settings

Expand Down
Loading

0 comments on commit bbf1bc1

Please sign in to comment.