Skip to content

Commit

Permalink
fix: abstract api request factory
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Nov 5, 2024
1 parent a3fc0c7 commit 7342906
Showing 1 changed file with 74 additions and 45 deletions.
119 changes: 74 additions & 45 deletions codeforlife/tests/api_request_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import typing as t

from django.contrib.auth.models import AbstractBaseUser
from django.core.handlers.wsgi import WSGIRequest
from rest_framework.parsers import (
FileUploadParser,
Expand All @@ -14,50 +15,43 @@
)
from rest_framework.test import APIRequestFactory as _APIRequestFactory

from ..request import Request
from ..request import BaseRequest, Request

# pylint: disable-next=duplicate-code
# pylint: disable=duplicate-code
if t.TYPE_CHECKING:
from ..user.models import User

AnyUser = t.TypeVar("AnyUser", bound=User)
else:
AnyUser = t.TypeVar("AnyUser")

AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest)
AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser)
# pylint: enable=duplicate-code

class APIRequestFactory(_APIRequestFactory, t.Generic[AnyUser]):
"""Custom API request factory that returns DRF's Request object."""

def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs):
super().__init__(*args, **kwargs)
self.user_class = user_class

@classmethod
def get_user_class(cls) -> t.Type[AnyUser]:
"""Get the user class.
Returns:
The user class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]

def request(self, user: t.Optional[AnyUser] = None, **kwargs):
wsgi_request = t.cast(WSGIRequest, super().request(**kwargs))
class BaseAPIRequestFactory(
_APIRequestFactory, t.Generic[AnyBaseRequest, AnyAbstractBaseUser]
):
"""Custom API request factory that returns DRF's Request object."""

request = Request(
self.user_class,
wsgi_request,
parsers=[
JSONParser(),
FormParser(),
MultiPartParser(),
FileUploadParser(),
],
def _init_request(self, wsgi_request: WSGIRequest):
return t.cast(
AnyBaseRequest,
BaseRequest(
wsgi_request,
parsers=[
JSONParser(),
FormParser(),
MultiPartParser(),
FileUploadParser(),
],
),
)

def request(self, user: t.Optional[AnyAbstractBaseUser] = None, **kwargs):
wsgi_request = t.cast(WSGIRequest, super().request(**kwargs))
request = self._init_request(wsgi_request)
if user:
# pylint: disable-next=attribute-defined-outside-init
request.user = user
Expand All @@ -72,11 +66,11 @@ def generic(
data: t.Optional[str] = None,
content_type: t.Optional[str] = None,
secure: bool = True,
user: t.Optional[AnyUser] = None,
user: t.Optional[AnyAbstractBaseUser] = None,
**extra
):
return t.cast(
Request[AnyUser],
AnyBaseRequest,
super().generic(
method,
path or "/",
Expand All @@ -92,11 +86,11 @@ def get( # type: ignore[override]
self,
path: t.Optional[str] = None,
data: t.Any = None,
user: t.Optional[AnyUser] = None,
user: t.Optional[AnyAbstractBaseUser] = None,
**extra
):
return t.cast(
Request[AnyUser],
AnyBaseRequest,
super().get(
path or "/",
data,
Expand All @@ -113,14 +107,14 @@ def post( # type: ignore[override]
# pylint: disable-next=redefined-builtin
format: t.Optional[str] = None,
content_type: t.Optional[str] = None,
user: t.Optional[AnyUser] = None,
user: t.Optional[AnyAbstractBaseUser] = None,
**extra
):
if format is None and content_type is None:
format = "json"

return t.cast(
Request[AnyUser],
AnyBaseRequest,
super().post(
path or "/",
data,
Expand All @@ -139,14 +133,14 @@ def put( # type: ignore[override]
# pylint: disable-next=redefined-builtin
format: t.Optional[str] = None,
content_type: t.Optional[str] = None,
user: t.Optional[AnyUser] = None,
user: t.Optional[AnyAbstractBaseUser] = None,
**extra
):
if format is None and content_type is None:
format = "json"

return t.cast(
Request[AnyUser],
AnyBaseRequest,
super().put(
path or "/",
data,
Expand All @@ -165,14 +159,14 @@ def patch( # type: ignore[override]
# pylint: disable-next=redefined-builtin
format: t.Optional[str] = None,
content_type: t.Optional[str] = None,
user: t.Optional[AnyUser] = None,
user: t.Optional[AnyAbstractBaseUser] = None,
**extra
):
if format is None and content_type is None:
format = "json"

return t.cast(
Request[AnyUser],
AnyBaseRequest,
super().patch(
path or "/",
data,
Expand All @@ -191,14 +185,14 @@ def delete( # type: ignore[override]
# pylint: disable-next=redefined-builtin
format: t.Optional[str] = None,
content_type: t.Optional[str] = None,
user: t.Optional[AnyUser] = None,
user: t.Optional[AnyAbstractBaseUser] = None,
**extra
):
if format is None and content_type is None:
format = "json"

return t.cast(
Request[AnyUser],
AnyBaseRequest,
super().delete(
path or "/",
data,
Expand All @@ -217,14 +211,14 @@ def options( # type: ignore[override]
# pylint: disable-next=redefined-builtin
format: t.Optional[str] = None,
content_type: t.Optional[str] = None,
user: t.Optional[AnyUser] = None,
user: t.Optional[AnyAbstractBaseUser] = None,
**extra
):
if format is None and content_type is None:
format = "json"

return t.cast(
Request[AnyUser],
AnyBaseRequest,
super().options(
path or "/",
data or {},
Expand All @@ -234,3 +228,38 @@ def options( # type: ignore[override]
**extra,
),
)


class APIRequestFactory(
BaseAPIRequestFactory[Request[AnyUser], AnyUser],
t.Generic[AnyUser],
):
"""Custom API request factory that returns DRF's Request object."""

def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs):
super().__init__(*args, **kwargs)
self.user_class = user_class

@classmethod
def get_user_class(cls) -> t.Type[AnyUser]:
"""Get the user class.
Returns:
The user class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]

def _init_request(self, wsgi_request):
return Request[AnyUser](
self.user_class,
wsgi_request,
parsers=[
JSONParser(),
FormParser(),
MultiPartParser(),
FileUploadParser(),
],
)

0 comments on commit 7342906

Please sign in to comment.