From 7342906c514e953a9dc53a3b43c30208fcdc35d5 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 17:39:25 +0000 Subject: [PATCH] fix: abstract api request factory --- codeforlife/tests/api_request_factory.py | 119 ++++++++++++++--------- 1 file changed, 74 insertions(+), 45 deletions(-) diff --git a/codeforlife/tests/api_request_factory.py b/codeforlife/tests/api_request_factory.py index 7565c02..0758e1a 100644 --- a/codeforlife/tests/api_request_factory.py +++ b/codeforlife/tests/api_request_factory.py @@ -5,6 +5,7 @@ import typing as t +from django.contrib.auth.models import AbstractBaseUser from django.core.handlers.wsgi import WSGIRequest from rest_framework.parsers import ( FileUploadParser, @@ -14,9 +15,9 @@ ) from rest_framework.test import APIRequestFactory as _APIRequestFactory -from ..request import Request +from ..request import BaseRequest, Request -# pylint: disable-next=duplicate-code +# pylint: disable=duplicate-code if t.TYPE_CHECKING: from ..user.models import User @@ -24,40 +25,33 @@ else: AnyUser = t.TypeVar("AnyUser") +AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest) +AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) +# pylint: enable=duplicate-code -class APIRequestFactory(_APIRequestFactory, t.Generic[AnyUser]): - """Custom API request factory that returns DRF's Request object.""" - - def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs): - super().__init__(*args, **kwargs) - self.user_class = user_class - - @classmethod - def get_user_class(cls) -> t.Type[AnyUser]: - """Get the user class. - - Returns: - The user class. - """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] - def request(self, user: t.Optional[AnyUser] = None, **kwargs): - wsgi_request = t.cast(WSGIRequest, super().request(**kwargs)) +class BaseAPIRequestFactory( + _APIRequestFactory, t.Generic[AnyBaseRequest, AnyAbstractBaseUser] +): + """Custom API request factory that returns DRF's Request object.""" - request = Request( - self.user_class, - wsgi_request, - parsers=[ - JSONParser(), - FormParser(), - MultiPartParser(), - FileUploadParser(), - ], + def _init_request(self, wsgi_request: WSGIRequest): + return t.cast( + AnyBaseRequest, + BaseRequest( + wsgi_request, + parsers=[ + JSONParser(), + FormParser(), + MultiPartParser(), + FileUploadParser(), + ], + ), ) + def request(self, user: t.Optional[AnyAbstractBaseUser] = None, **kwargs): + wsgi_request = t.cast(WSGIRequest, super().request(**kwargs)) + request = self._init_request(wsgi_request) if user: # pylint: disable-next=attribute-defined-outside-init request.user = user @@ -72,11 +66,11 @@ def generic( data: t.Optional[str] = None, content_type: t.Optional[str] = None, secure: bool = True, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): return t.cast( - Request[AnyUser], + AnyBaseRequest, super().generic( method, path or "/", @@ -92,11 +86,11 @@ def get( # type: ignore[override] self, path: t.Optional[str] = None, data: t.Any = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): return t.cast( - Request[AnyUser], + AnyBaseRequest, super().get( path or "/", data, @@ -113,14 +107,14 @@ def post( # type: ignore[override] # pylint: disable-next=redefined-builtin format: t.Optional[str] = None, content_type: t.Optional[str] = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): if format is None and content_type is None: format = "json" return t.cast( - Request[AnyUser], + AnyBaseRequest, super().post( path or "/", data, @@ -139,14 +133,14 @@ def put( # type: ignore[override] # pylint: disable-next=redefined-builtin format: t.Optional[str] = None, content_type: t.Optional[str] = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): if format is None and content_type is None: format = "json" return t.cast( - Request[AnyUser], + AnyBaseRequest, super().put( path or "/", data, @@ -165,14 +159,14 @@ def patch( # type: ignore[override] # pylint: disable-next=redefined-builtin format: t.Optional[str] = None, content_type: t.Optional[str] = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): if format is None and content_type is None: format = "json" return t.cast( - Request[AnyUser], + AnyBaseRequest, super().patch( path or "/", data, @@ -191,14 +185,14 @@ def delete( # type: ignore[override] # pylint: disable-next=redefined-builtin format: t.Optional[str] = None, content_type: t.Optional[str] = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): if format is None and content_type is None: format = "json" return t.cast( - Request[AnyUser], + AnyBaseRequest, super().delete( path or "/", data, @@ -217,14 +211,14 @@ def options( # type: ignore[override] # pylint: disable-next=redefined-builtin format: t.Optional[str] = None, content_type: t.Optional[str] = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): if format is None and content_type is None: format = "json" return t.cast( - Request[AnyUser], + AnyBaseRequest, super().options( path or "/", data or {}, @@ -234,3 +228,38 @@ def options( # type: ignore[override] **extra, ), ) + + +class APIRequestFactory( + BaseAPIRequestFactory[Request[AnyUser], AnyUser], + t.Generic[AnyUser], +): + """Custom API request factory that returns DRF's Request object.""" + + def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs): + super().__init__(*args, **kwargs) + self.user_class = user_class + + @classmethod + def get_user_class(cls) -> t.Type[AnyUser]: + """Get the user class. + + Returns: + The user class. + """ + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 0 + ] + + def _init_request(self, wsgi_request): + return Request[AnyUser]( + self.user_class, + wsgi_request, + parsers=[ + JSONParser(), + FormParser(), + MultiPartParser(), + FileUploadParser(), + ], + )