Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Sep 13, 2024
1 parent debfa28 commit fd633ca
Show file tree
Hide file tree
Showing 12 changed files with 584 additions and 133 deletions.
6 changes: 3 additions & 3 deletions api/auth/backends/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ def authenticate( # type: ignore[override]
if not response.ok:
return None

auth_data: JsonDict = response.json()
if "error" in auth_data:
access_token: JsonDict = response.json()
if "error" in access_token:
return None

return Contributor.sync_with_github(
auth=f"{auth_data['token_type']} {auth_data['access_token']}"
auth=f"{access_token['token_type']} {access_token['access_token']}"
)

# pylint: disable-next=arguments-renamed
Expand Down
1 change: 1 addition & 0 deletions api/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .model_serializer import ModelSerializer
from .model_serializer_test_case import ModelSerializerTestCase
from .model_view_set import ModelViewSet
from .model_view_set_test_case import ModelViewSetTestCase
from .request import Request
212 changes: 212 additions & 0 deletions api/common/api_request_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
"""
© Ocado Group
Created on 08/02/2024 at 15:42:25(+00:00).
"""

import typing as t

from django.core.handlers.wsgi import WSGIRequest
from rest_framework.parsers import (
FileUploadParser,
FormParser,
JSONParser,
MultiPartParser,
)
from rest_framework.test import APIRequestFactory as _APIRequestFactory

from ..models import Contributor
from .request import Request


class APIRequestFactory(_APIRequestFactory):
"""Custom API request factory that returns DRF's Request object."""

def request(self, user: t.Optional[Contributor] = None, **kwargs):
wsgi_request = t.cast(WSGIRequest, super().request(**kwargs))

request = Request(
wsgi_request,
parsers=[
JSONParser(),
FormParser(),
MultiPartParser(),
FileUploadParser(),
],
)

if user:
# pylint: disable-next=attribute-defined-outside-init
request.user = user

return request

# pylint: disable-next=too-many-arguments
def generic(
self,
method: str,
path: t.Optional[str] = None,
data: t.Optional[str] = None,
content_type: t.Optional[str] = None,
secure: bool = True,
user: t.Optional[Contributor] = None,
**extra
):
return t.cast(
Request,
super().generic(
method,
path or "/",
data or "",
content_type or "application/json",
secure,
user=user,
**extra,
),
)

def get( # type: ignore[override]
self,
path: t.Optional[str] = None,
data: t.Any = None,
user: t.Optional[Contributor] = None,
**extra
):
return t.cast(
Request,
super().get(
path or "/",
data,
user=user,
**extra,
),
)

# pylint: disable-next=too-many-arguments
def post( # type: ignore[override]
self,
path: t.Optional[str] = None,
data: t.Any = None,
# pylint: disable-next=redefined-builtin
format: t.Optional[str] = None,
content_type: t.Optional[str] = None,
user: t.Optional[Contributor] = None,
**extra
):
if format is None and content_type is None:
format = "json"

return t.cast(
Request,
super().post(
path or "/",
data,
format,
content_type,
user=user,
**extra,
),
)

# pylint: disable-next=too-many-arguments
def put( # type: ignore[override]
self,
path: t.Optional[str] = None,
data: t.Any = None,
# pylint: disable-next=redefined-builtin
format: t.Optional[str] = None,
content_type: t.Optional[str] = None,
user: t.Optional[Contributor] = None,
**extra
):
if format is None and content_type is None:
format = "json"

return t.cast(
Request,
super().put(
path or "/",
data,
format,
content_type,
user=user,
**extra,
),
)

# pylint: disable-next=too-many-arguments
def patch( # type: ignore[override]
self,
path: t.Optional[str] = None,
data: t.Any = None,
# pylint: disable-next=redefined-builtin
format: t.Optional[str] = None,
content_type: t.Optional[str] = None,
user: t.Optional[Contributor] = None,
**extra
):
if format is None and content_type is None:
format = "json"

return t.cast(
Request,
super().patch(
path or "/",
data,
format,
content_type,
user=user,
**extra,
),
)

# pylint: disable-next=too-many-arguments
def delete( # type: ignore[override]
self,
path: t.Optional[str] = None,
data: t.Any = None,
# pylint: disable-next=redefined-builtin
format: t.Optional[str] = None,
content_type: t.Optional[str] = None,
user: t.Optional[Contributor] = None,
**extra
):
if format is None and content_type is None:
format = "json"

return t.cast(
Request,
super().delete(
path or "/",
data,
format,
content_type,
user=user,
**extra,
),
)

# pylint: disable-next=too-many-arguments
def options( # type: ignore[override]
self,
path: t.Optional[str] = None,
data: t.Optional[t.Union[t.Dict[str, str], str]] = None,
# pylint: disable-next=redefined-builtin
format: t.Optional[str] = None,
content_type: t.Optional[str] = None,
user: t.Optional[Contributor] = None,
**extra
):
if format is None and content_type is None:
format = "json"

return t.cast(
Request,
super().options(
path or "/",
data or {},
format,
content_type,
user=user,
**extra,
),
)
1 change: 0 additions & 1 deletion api/common/model_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from codeforlife.types import DataDict
from django.db.models import Model
from django.views import View
from rest_framework.serializers import ModelSerializer as _ModelSerializer

from .request import Request
Expand Down
47 changes: 47 additions & 0 deletions api/common/model_serializer_test_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import typing as t

from codeforlife.tests import (
ModelSerializerTestCase as _ModelSerializerTestCase,
)
from django.db.models import Model

from ..models import Contributor
from .api_request_factory import APIRequestFactory
from .model_serializer import ModelSerializer

AnyModel = t.TypeVar("AnyModel", bound=Model)


class ModelSerializerTestCase(_ModelSerializerTestCase, t.Generic[AnyModel]):
model_serializer_class: t.Type[ModelSerializer[AnyModel]]

request_factory: APIRequestFactory

@classmethod
def setUpClass(cls):
result = super().setUpClass()

cls.request_factory = APIRequestFactory()

return result

@classmethod
def get_request_user_class(cls) -> t.Type[AnyModel]:
"""Get the model view set's class.
Returns:
The model view set's class.
"""
return Contributor

@classmethod
def get_model_class(cls) -> t.Type[AnyModel]:
"""Get the model view set's class.
Returns:
The model view set's class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]
Loading

0 comments on commit fd633ca

Please sign in to comment.