diff --git a/README.rst b/README.rst index b32ac4fd..d3bc88c9 100644 --- a/README.rst +++ b/README.rst @@ -47,6 +47,7 @@ Features - `djangorestframework-recursive `_ - `djangorestframework-dataclasses `_ - `django-rest-framework-gis `_ + - `Pydantic (>=2.0) `_ For more information visit the `documentation `_. diff --git a/drf_spectacular/contrib/__init__.py b/drf_spectacular/contrib/__init__.py index 6fd2abf2..90e8b8f6 100644 --- a/drf_spectacular/contrib/__init__.py +++ b/drf_spectacular/contrib/__init__.py @@ -10,4 +10,5 @@ 'django_filters', 'rest_framework_recursive', 'rest_framework_gis', + 'pydantic', ] diff --git a/drf_spectacular/contrib/pydantic.py b/drf_spectacular/contrib/pydantic.py new file mode 100644 index 00000000..2430136f --- /dev/null +++ b/drf_spectacular/contrib/pydantic.py @@ -0,0 +1,49 @@ +from drf_spectacular.drainage import set_override, warn +from drf_spectacular.extensions import OpenApiSerializerExtension +from drf_spectacular.plumbing import ResolvedComponent, build_basic_type +from drf_spectacular.types import OpenApiTypes + + +class PydanticExtension(OpenApiSerializerExtension): + """ + Allows using pydantic models on @extend_schema(request=..., response=...) to + describe your API. + + We only have partial support for pydantic's version of dataclass, due to the way they + are designed. The outermost class (the @extend_schema argument) has to be a subclass + of pydantic.BaseModel. Inside this outermost BaseModel, any combination of dataclass + and BaseModel can be used. + """ + + target_class = "pydantic.BaseModel" + match_subclasses = True + + def get_name(self, auto_schema, direction): + # due to the fact that it is complicated to pull out every field member BaseModel class + # of the entry model, we simply use the class name as string for object. This hack may + # create false positive warnings, so turn it off. However, this may suppress correct + # warnings involving the entry class. + set_override(self.target, 'suppress_collision_warning', True) + return self.target.__name__ + + def map_serializer(self, auto_schema, direction): + # let pydantic generate a JSON schema + try: + from pydantic.json_schema import model_json_schema + except ImportError: + warn("Only pydantic >= 2 is supported. defaulting to generic object.") + return build_basic_type(OpenApiTypes.OBJECT) + + schema = model_json_schema(self.target, ref_template="#/components/schemas/{model}") + + # pull out potential sub-schemas and put them into component section + for sub_name, sub_schema in schema.pop("$defs", {}).items(): + component = ResolvedComponent( + name=sub_name, + type=ResolvedComponent.SCHEMA, + object=sub_name, + schema=sub_schema, + ) + auto_schema.registry.register_on_missing(component) + + return schema diff --git a/requirements/optionals.txt b/requirements/optionals.txt index 7ce28e34..a8ef61a7 100644 --- a/requirements/optionals.txt +++ b/requirements/optionals.txt @@ -13,3 +13,4 @@ djangorestframework-recursive>=0.1.2 drf-spectacular-sidecar djangorestframework-dataclasses>=1.0.0; python_version >= '3.7' djangorestframework-gis>=1.0.0 +pydantic>=2,<3; python_version >= '3.7' diff --git a/tests/contrib/test_pydantic.py b/tests/contrib/test_pydantic.py new file mode 100644 index 00000000..977939ae --- /dev/null +++ b/tests/contrib/test_pydantic.py @@ -0,0 +1,45 @@ +import sys +from typing import List + +import pytest +from rest_framework.views import APIView + +from drf_spectacular.utils import extend_schema +from tests import assert_schema, generate_schema + +try: + from pydantic import BaseModel + from pydantic.dataclasses import dataclass +except ImportError: + class BaseModel: + pass + + def dataclass(f): + return f + + +@dataclass +class C: + id: int + + +class B(BaseModel): + id: int + c: List[C] + + +class A(BaseModel): + id: int + b: B + + +@pytest.mark.contrib('pydantic') +@pytest.mark.skipif(sys.version_info < (3, 7), reason='python 3.7+ is required by package') +def test_pydantic_decoration(no_warnings): + class XAPIView(APIView): + @extend_schema(request=A, responses=B) + def post(self, request): + pass # pragma: no cover + + schema = generate_schema('/x', view=XAPIView) + assert_schema(schema, 'tests/contrib/test_pydantic.yml') diff --git a/tests/contrib/test_pydantic.yml b/tests/contrib/test_pydantic.yml new file mode 100644 index 00000000..ac62fdc2 --- /dev/null +++ b/tests/contrib/test_pydantic.yml @@ -0,0 +1,79 @@ +openapi: 3.0.3 +info: + title: '' + version: 0.0.0 +paths: + /x: + post: + operationId: x_create + tags: + - x + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/A' + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/A' + multipart/form-data: + schema: + $ref: '#/components/schemas/A' + required: true + security: + - cookieAuth: [] + - basicAuth: [] + - {} + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/B' + description: '' +components: + schemas: + A: + properties: + id: + title: Id + type: integer + b: + $ref: '#/components/schemas/B' + required: + - id + - b + title: A + type: object + B: + properties: + id: + title: Id + type: integer + c: + items: + $ref: '#/components/schemas/C' + title: C + type: array + required: + - id + - c + title: B + type: object + C: + properties: + id: + title: Id + type: integer + required: + - id + title: C + type: object + securitySchemes: + basicAuth: + type: http + scheme: basic + cookieAuth: + type: apiKey + in: cookie + name: sessionid