Skip to content

Commit

Permalink
Merge pull request #837 from onegreyonewhite/1.21.x
Browse files Browse the repository at this point in the history
Fix: Provide support for enums in codecs.
  • Loading branch information
JoelLefkowitz authored Mar 16, 2023
2 parents 5530022 + a00fd87 commit 353f071
Show file tree
Hide file tree
Showing 19 changed files with 55 additions and 66 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/review.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
runs-on: "ubuntu-20.04"
strategy:
matrix:
python: ["3.7", "3.8", "3.9", "3.10"]
python: ["3.7", "3.8", "3.9", "3.10", "3.11"]

steps:
- name: Checkout the source code
Expand Down
8 changes: 4 additions & 4 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Compatible with

- **Django Rest Framework**: 3.10, 3.11, 3.12, 3.13, 3.14
- **Django**: 2.2, 3.0, 3.1, 3.2, 4.0, 4.1
- **Python**: 3.6, 3.7, 3.8, 3.9, 3.10
- **Python**: 3.6, 3.7, 3.8, 3.9, 3.10, 3.11

Only the latest patch version of each ``major.minor`` series of Python, Django and Django REST Framework is supported.

Expand Down Expand Up @@ -362,23 +362,23 @@ provided out of the box - if you have ``djangorestframework-recursive`` installe
drf-extra-fields
=================

Integration with `drf-extra-fields <https://github.com/Hipo/drf-extra-fields>`_ has a problem with Base64 fields.
Integration with `drf-extra-fields <https://github.com/Hipo/drf-extra-fields>`_ has a problem with Base64 fields.
The drf-yasg will generate Base64 file or image fields as Readonly and not required. Here is a workaround code
for display the Base64 fields correctly.

.. code:: python
class PDFBase64FileField(Base64FileField):
ALLOWED_TYPES = ['pdf']
class Meta:
swagger_schema_fields = {
'type': 'string',
'title': 'File Content',
'description': 'Content of the file base64 encoded',
'read_only': False # <-- FIX
}
def get_file_extension(self, filename, decoded_file):
try:
PyPDF2.PdfFileReader(io.BytesIO(decoded_file))
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
language = 'en'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
Expand Down
6 changes: 3 additions & 3 deletions requirements/lint.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# used by the 'lint' tox env for linting via flake8
isort>=4.2
flake8>=3.5.0
flake8-isort>=2.3
isort>=5.12
flake8>=6.0.0
flake8-isort>=6.0.0
2 changes: 1 addition & 1 deletion requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pytest-cov>=2.6.0
pytest-xdist>=1.25.0
pytest-django>=3.4.4
datadiff==2.0.0
psycopg2-binary==2.9.4
psycopg2-binary==2.9.5
django-fake-model==0.1.4

-r testproj.txt
2 changes: 1 addition & 1 deletion requirements/tox.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# requirements for building and running tox
tox>=3.3.0
tox>=3.3.0,<4
2 changes: 1 addition & 1 deletion src/drf_yasg/app_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
]


class AppSettings(object):
class AppSettings:
"""
Stolen from Django Rest Framework, removed caching for easier testing
"""
Expand Down
2 changes: 1 addition & 1 deletion src/drf_yasg/codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _validate_swagger_spec_validator(spec):
}


class _OpenAPICodec(object):
class _OpenAPICodec:
media_type = None

def __init__(self, validators):
Expand Down
2 changes: 1 addition & 1 deletion src/drf_yasg/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def unescape_path(self, path):
return clean_path


class OpenAPISchemaGenerator(object):
class OpenAPISchemaGenerator:
"""
This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema.
Method implementations shamelessly stolen and adapted from rest-framework ``SchemaGenerator``.
Expand Down
2 changes: 1 addition & 1 deletion src/drf_yasg/inspectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def call_view_method(view, method_name, fallback_attr=None, default=None):
return default


class BaseInspector(object):
class BaseInspector:
def __init__(self, view, path, method, components, request):
"""
:param rest_framework.views.APIView view: the view associated with this endpoint
Expand Down
12 changes: 5 additions & 7 deletions src/drf_yasg/inspectors/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,25 @@
import inspect
import logging
import operator
import typing
import uuid
import pkg_resources
from packaging import version
from collections import OrderedDict
from decimal import Decimal
from inspect import signature as inspect_signature

import pkg_resources
import typing
from django.core import validators
from django.db import models
from packaging import version
from rest_framework import serializers
from rest_framework.settings import api_settings as rest_framework_settings

from .base import call_view_method, FieldInspector, NotHandled, SerializerInspector
from .. import openapi
from ..errors import SwaggerGenerationError
from ..utils import (
decimal_as_float, field_value_to_representation, filter_none, get_serializer_class, get_serializer_ref_name
)
from .base import FieldInspector, NotHandled, SerializerInspector, call_view_method


drf_version = pkg_resources.get_distribution("djangorestframework").version

Expand Down Expand Up @@ -394,7 +393,7 @@ def decimal_field_type(field):
(models.TimeField, (openapi.TYPE_STRING, None)),
(models.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
(models.CharField, (openapi.TYPE_STRING, None)),
]
]

ip_format = {'ipv4': openapi.FORMAT_IPV4, 'ipv6': openapi.FORMAT_IPV6}

Expand Down Expand Up @@ -852,4 +851,3 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
return ref

return NotHandled

2 changes: 1 addition & 1 deletion src/drf_yasg/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .errors import SwaggerValidationError


class SwaggerExceptionMiddleware(object):
class SwaggerExceptionMiddleware:
def __init__(self, get_response):
self.get_response = get_response

Expand Down
33 changes: 15 additions & 18 deletions src/drf_yasg/openapi.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
import collections
import enum
import logging
import re
import urllib.parse as urlparse
from collections import OrderedDict
from collections import OrderedDict, abc as collections_abc

from django.urls import get_script_prefix
from django.utils.functional import Promise
from inflection import camelize

from .utils import dict_has_ordered_keys, filter_none, force_real_str

try:
from collections import abc as collections_abc
except ImportError:
collections_abc = collections
from .utils import filter_none, force_real_str

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -143,15 +138,17 @@ def _as_odict(obj, memo):
result = OrderedDict()
memo[id(obj)] = result
items = obj.items()
if not dict_has_ordered_keys(obj):
if not isinstance(obj, dict):
items = sorted(items)
for attr, val in items:
result[attr] = SwaggerDict._as_odict(val, memo)
result[SwaggerDict._as_odict(attr, memo)] = SwaggerDict._as_odict(val, memo)
return result
elif isinstance(obj, str):
return force_real_str(obj)
elif isinstance(obj, collections_abc.Iterable) and not isinstance(obj, collections_abc.Iterator):
return type(obj)(SwaggerDict._as_odict(elem, memo) for elem in obj)
elif isinstance(obj, enum.Enum):
return obj.value

return obj

Expand Down Expand Up @@ -236,7 +233,7 @@ def __init__(self, info=None, _url=None, _prefix=None, _version=None, consumes=N
security_definitions=None, security=None, paths=None, definitions=None, **extra):
"""Root Swagger object.
:param .Info info: info object
:param Info info: info object
:param str _url: URL used for setting the API host and scheme
:param str _prefix: api path prefix to use in setting basePath; this will be appended to the wsgi
SCRIPT_NAME prefix or Django's FORCE_SCRIPT_NAME if applicable
Expand Down Expand Up @@ -391,7 +388,7 @@ def __init__(self, type=None, format=None, enum=None, pattern=None, items=None,
:param str format: value format, see OpenAPI spec
:param list enum: restrict possible values
:param str pattern: pattern if type is ``string``
:param .Items items: only valid if `type` is ``array``
:param Items items: only valid if `type` is ``array``
"""
super(Items, self).__init__(**extra)
assert type is not None, "type is required!"
Expand Down Expand Up @@ -420,7 +417,7 @@ def __init__(self, name, in_, description=None, required=None, schema=None,
:param str format: value format, see OpenAPI spec
:param list enum: restrict possible values
:param str pattern: pattern if type is ``string``
:param .Items items: only valid if `type` is ``array``
:param Items items: only valid if `type` is ``array``
:param default: default value if the parameter is not provided; must conform to parameter type
"""
super(Parameter, self).__init__(**extra)
Expand Down Expand Up @@ -512,10 +509,10 @@ def __init__(self, resolver, name, scope, expected_type, ignore_unresolved=False
"""Base class for all reference types. A reference object has only one property, ``$ref``, which must be a JSON
reference to a valid object in the specification, e.g. ``#/definitions/Article`` to refer to an article model.
:param .ReferenceResolver resolver: component resolver which must contain the referenced object
:param ReferenceResolver resolver: component resolver which must contain the referenced object
:param str name: referenced object name, e.g. "Article"
:param str scope: reference scope, e.g. "definitions"
:param type[.SwaggerDict] expected_type: the expected type that will be asserted on the object found in resolver
:param type[SwaggerDict] expected_type: the expected type that will be asserted on the object found in resolver
:param bool ignore_unresolved: do not throw if the referenced object does not exist
"""
super(_Ref, self).__init__()
Expand All @@ -530,7 +527,7 @@ def __init__(self, resolver, name, scope, expected_type, ignore_unresolved=False
def resolve(self, resolver):
"""Get the object targeted by this reference from the given component resolver.
:param .ReferenceResolver resolver: component resolver which must contain the referenced object
:param ReferenceResolver resolver: component resolver which must contain the referenced object
:returns: the target object
"""
ref_match = self.ref_name_re.match(self.ref)
Expand All @@ -549,7 +546,7 @@ class SchemaRef(_Ref):
def __init__(self, resolver, schema_name, ignore_unresolved=False):
"""Adds a reference to a named Schema defined in the ``#/definitions/`` object.
:param .ReferenceResolver resolver: component resolver which must contain the definition
:param ReferenceResolver resolver: component resolver which must contain the definition
:param str schema_name: schema name
:param bool ignore_unresolved: do not throw if the referenced object does not exist
"""
Expand Down Expand Up @@ -647,7 +644,7 @@ def with_scope(self, scope):
:param str scope: target scope, must be in this resolver's `scopes`
:return: the bound resolver
:rtype: .ReferenceResolver
:rtype: ReferenceResolver
"""
assert scope in self.scopes, "unknown scope %s" % scope
ret = ReferenceResolver(force_init=True)
Expand Down
18 changes: 2 additions & 16 deletions src/drf_yasg/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import inspect
import logging
import sys
import textwrap
from collections import OrderedDict
from decimal import Decimal
Expand Down Expand Up @@ -322,8 +321,8 @@ def force_serializer_instance(serializer):


def get_serializer_class(serializer):
"""Given a ``Serializer`` class or instance, return the ``Serializer`` class. If `serializer` is not a ``Serializer``
class or instance, raises an assertion error.
"""Given a ``Serializer`` class or instance, return the ``Serializer`` class.
If `serializer` is not a ``Serializer`` class or instance, raises an assertion error.
:param serializer: serializer class or instance, or ``None``
:return: serializer class
Expand Down Expand Up @@ -505,16 +504,3 @@ def get_field_default(field):
default = serializers.empty

return default


def dict_has_ordered_keys(obj):
"""Check if a given object is a dict that maintains insertion order.
:param obj: the dict object to check
:rtype: bool
"""
if sys.version_info >= (3, 7):
# the Python 3.7 language spec says that dict must maintain insertion order.
return isinstance(obj, dict)

return isinstance(obj, OrderedDict)
2 changes: 1 addition & 1 deletion src/drf_yasg/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_schema_view(info=None, url=None, patterns=None, urlconf=None, public=Fal
generator_class=None, authentication_classes=None, permission_classes=None):
"""Create a SchemaView class with default renderers and generators.
:param .Info info: information about the API; if omitted, defaults to :ref:`DEFAULT_INFO <default-swagger-settings>`
:param Info info: information about the API; if omitted, defaults to :ref:`DEFAULT_INFO <default-swagger-settings>`
:param str url: same as :class:`.OpenAPISchemaGenerator`
:param patterns: same as :class:`.OpenAPISchemaGenerator`
:param urlconf: same as :class:`.OpenAPISchemaGenerator`
Expand Down
2 changes: 1 addition & 1 deletion testproj/todo/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
router.register(r'another', views.TodoAnotherViewSet)
router.register(r'yetanother', views.TodoYetAnotherViewSet)
router.register(r'tree', views.TodoTreeView)
router.register(r'recursive', views.TodoRecursiveView)
router.register(r'recursive', views.TodoRecursiveView, basename='todorecursivetree')
router.register(r'harvest', views.HarvestViewSet)

urlpatterns = router.urls
Expand Down
1 change: 1 addition & 0 deletions testproj/users/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class Meta:

ref_name = "UserSerializer"


class UserListQuerySerializer(serializers.Serializer):
username = serializers.CharField(help_text="this field is generated from a query_serializer", required=False)
is_staff = serializers.BooleanField(help_text="this one too!", required=False)
Expand Down
10 changes: 8 additions & 2 deletions tests/test_reference_schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import json
from collections import OrderedDict

Expand All @@ -10,6 +11,11 @@ def test_reference_schema(swagger_dict, reference_schema, compare_schemas):
compare_schemas(swagger_dict, reference_schema)


class VerisonEnum(enum.Enum):
V1 = 'v1'
V2 = 'v2'


class NoOpFieldInspector(FieldInspector):
pass

Expand Down Expand Up @@ -38,8 +44,8 @@ def set_inspectors(inspectors, setting_name):
set_inspectors([NoOpPaginatorInspector], 'DEFAULT_PAGINATOR_INSPECTORS')

generator = OpenAPISchemaGenerator(
info=openapi.Info(title="Test generator", default_version="v1"),
version="v2",
info=openapi.Info(title="Test generator", default_version=VerisonEnum.V1),
version=VerisonEnum.V2,
)
swagger = generator.get_schema(mock_schema_request, True)

Expand Down
Loading

0 comments on commit 353f071

Please sign in to comment.