Skip to content

Commit

Permalink
abstract model view set test case and client
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Nov 7, 2024
1 parent 04966b5 commit ab235bf
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 42 deletions.
118 changes: 85 additions & 33 deletions codeforlife/tests/model_view_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,20 @@
import typing as t
from datetime import datetime

from django.contrib.auth import get_user_model
from django.core.exceptions import ObjectDoesNotExist
from django.db.models import Model
from django.urls import reverse

from ..models import AbstractBaseUser
from ..permissions import Permission
from ..serializers import BaseSerializer
from ..types import DataDict, JsonDict, KwArgs
from ..views import ModelViewSet
from .api import APITestCase
from .model_view_set_client import ModelViewSetClient
from ..views import BaseModelViewSet, ModelViewSet
from .api import APITestCase, BaseAPITestCase
from .model_view_set_client import BaseModelViewSetClient, ModelViewSetClient

# pylint: disable-next=duplicate-code
# pylint: disable=duplicate-code
if t.TYPE_CHECKING:
from ..user.models import User

Expand All @@ -28,21 +30,32 @@
RequestUser = t.TypeVar("RequestUser")

AnyModel = t.TypeVar("AnyModel", bound=Model)
# pylint: disable=no-member,too-many-arguments
AnyBaseModelViewSetClient = t.TypeVar(
"AnyBaseModelViewSetClient", bound=BaseModelViewSetClient
)
AnyBaseModelViewSet = t.TypeVar("AnyBaseModelViewSet", bound=BaseModelViewSet)
# pylint: enable=duplicate-code


# pylint: disable-next=too-many-ancestors
class ModelViewSetTestCase(
APITestCase[RequestUser], t.Generic[RequestUser, AnyModel]
class BaseModelViewSetTestCase(
BaseAPITestCase[AnyBaseModelViewSetClient],
t.Generic[AnyBaseModelViewSet, AnyBaseModelViewSetClient, AnyModel],
):
"""Base for all model view set test cases."""

basename: str
model_view_set_class: t.Type[ModelViewSet[RequestUser, AnyModel]]
client: ModelViewSetClient[RequestUser, AnyModel]
client_class: t.Type[ModelViewSetClient[RequestUser, AnyModel]] = (
ModelViewSetClient
)
model_view_set_class: t.Type[AnyBaseModelViewSet]

REQUIRED_ATTRS: t.Set[str] = {"model_view_set_class", "basename"}

@classmethod
def get_request_user_class(cls):
"""Get the request's user class.
Returns:
The request's user class.
"""
return t.cast(AbstractBaseUser, get_user_model())

@classmethod
def get_model_class(cls) -> t.Type[AnyModel]:
Expand All @@ -53,29 +66,16 @@ def get_model_class(cls) -> t.Type[AnyModel]:
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
1
2
]

@classmethod
def setUpClass(cls):
for attr in ["model_view_set_class", "basename"]:
for attr in cls.REQUIRED_ATTRS:
assert hasattr(cls, attr), f'Attribute "{attr}" must be set.'

return super().setUpClass()

def _get_client_class(self):
# TODO: unpack type args in index after moving to python 3.11
# pylint: disable-next=too-few-public-methods
class _Client(
self.client_class[ # type: ignore[misc]
self.get_request_user_class(),
self.get_model_class(),
]
):
_test_case = self

return _Client

def reverse_action(
self,
name: str,
Expand Down Expand Up @@ -113,6 +113,7 @@ def reverse_action(
# Assertion Helpers
# --------------------------------------------------------------------------

# pylint: disable-next=too-many-arguments
def assert_serialized_model_equals_json_model(
self,
model: AnyModel,
Expand All @@ -135,18 +136,16 @@ def assert_serialized_model_equals_json_model(
"""
# Get the logged-in user.
try:
user = t.cast(
RequestUser,
self.get_request_user_class().objects.get(
session=self.client.session.session_key
),
user = self.get_request_user_class().objects.get(
session=self.client.session.session_key
)
except ObjectDoesNotExist:
user = None # NOTE: no user has logged in.

# Create an instance of the model view set and serializer.
model_view_set = self.model_view_set_class(
action=action.replace("-", "_"),
# pylint: disable-next=no-member
request=self.client.request_factory.generic(
request_method, user=user
),
Expand Down Expand Up @@ -239,6 +238,7 @@ def assert_get_serializer_context(
serializer_context: The serializer's context.
action: The model view set's action.
"""
# pylint: disable-next=no-member
kwargs.setdefault("request", self.client.request_factory.get())
kwargs.setdefault("format_kwarg", None)
model_view_set = self.model_view_set_class(
Expand All @@ -249,3 +249,55 @@ def assert_get_serializer_context(
actual_serializer_context | serializer_context,
actual_serializer_context,
)


# pylint: disable-next=too-many-ancestors
class ModelViewSetTestCase(
BaseModelViewSetTestCase[
ModelViewSet[RequestUser, AnyModel],
ModelViewSetClient[RequestUser, AnyModel],
AnyModel,
],
APITestCase[RequestUser],
t.Generic[RequestUser, AnyModel],
):
"""Base for all model view set test cases."""

client_class = ModelViewSetClient

@classmethod
def get_request_user_class(cls) -> t.Type[RequestUser]:
"""Get the request's user class.
Returns:
The request's user class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]

@classmethod
def get_model_class(cls) -> t.Type[AnyModel]:
"""Get the model view set's class.
Returns:
The model view set's class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
1
]

def _get_client_class(self):
# TODO: unpack type args in index after moving to python 3.11
# pylint: disable-next=too-few-public-methods
class _Client(
self.client_class[ # type: ignore[misc]
self.get_request_user_class(),
self.get_model_class(),
]
):
_test_case = self

return _Client
43 changes: 34 additions & 9 deletions codeforlife/tests/model_view_set_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,39 @@
from rest_framework.response import Response

from ..types import DataDict, JsonDict, KwArgs
from .api import APIClient
from .api import APIClient, BaseAPIClient
from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory

# pylint: disable-next=duplicate-code
# pylint: disable=duplicate-code
if t.TYPE_CHECKING:
from ..user.models import User
from .model_view_set import ModelViewSetTestCase
from .model_view_set import BaseModelViewSetTestCase, ModelViewSetTestCase

RequestUser = t.TypeVar("RequestUser", bound=User)
AnyBaseModelViewSetTestCase = t.TypeVar(
"AnyBaseModelViewSetTestCase", bound=BaseModelViewSetTestCase
)
else:
RequestUser = t.TypeVar("RequestUser")
AnyBaseModelViewSetTestCase = t.TypeVar("AnyBaseModelViewSetTestCase")

AnyModel = t.TypeVar("AnyModel", bound=Model)
# pylint: disable=no-member,too-many-arguments
AnyBaseAPIRequestFactory = t.TypeVar(
"AnyBaseAPIRequestFactory", bound=BaseAPIRequestFactory
)
# pylint: enable=duplicate-code


# pylint: disable-next=too-many-ancestors
class ModelViewSetClient(
APIClient[RequestUser], t.Generic[RequestUser, AnyModel]
class BaseModelViewSetClient(
BaseAPIClient[AnyBaseModelViewSetTestCase, AnyBaseAPIRequestFactory],
t.Generic[AnyBaseModelViewSetTestCase, AnyBaseAPIRequestFactory],
):
"""
An API client that helps make requests to a model view set and assert their
responses.
"""

_test_case: "ModelViewSetTestCase[RequestUser, AnyModel]"

@property
def _model_class(self):
"""Shortcut to get model class."""
Expand Down Expand Up @@ -197,6 +204,7 @@ def retrieve(

return response

# pylint: disable-next=too-many-arguments
def list(
self,
models: t.Collection[AnyModel],
Expand Down Expand Up @@ -258,6 +266,7 @@ def _make_assertions(response_json: JsonDict):
# Partial Update (HTTP PATCH)
# --------------------------------------------------------------------------

# pylint: disable-next=too-many-arguments
def _assert_update(
self,
model: AnyModel,
Expand All @@ -271,6 +280,7 @@ def _assert_update(
model, json_model, action, request_method, contains_subset=partial
)

# pylint: disable-next=too-many-arguments
def partial_update(
self,
model: AnyModel,
Expand Down Expand Up @@ -321,6 +331,7 @@ def partial_update(

return response

# pylint: disable-next=too-many-arguments
def bulk_partial_update(
self,
models: t.Union[t.List[AnyModel], QuerySet[AnyModel]],
Expand Down Expand Up @@ -381,6 +392,7 @@ def _make_assertions(json_models: t.List[JsonDict]):
# Update (HTTP PUT)
# --------------------------------------------------------------------------

# pylint: disable-next=too-many-arguments
def update(
self,
model: AnyModel,
Expand Down Expand Up @@ -431,6 +443,7 @@ def update(

return response

# pylint: disable-next=too-many-arguments
def bulk_update(
self,
models: t.Union[t.List[AnyModel], QuerySet[AnyModel]],
Expand Down Expand Up @@ -605,4 +618,16 @@ def cron_job(self, action: str):
return response


# pylint: enable=no-member
# pylint: disable-next=too-many-ancestors
class ModelViewSetClient( # type: ignore[misc]
BaseModelViewSetClient[
"ModelViewSetTestCase[RequestUser, AnyModel]",
APIRequestFactory[RequestUser],
],
APIClient[RequestUser],
t.Generic[RequestUser, AnyModel],
):
"""
An API client that helps make requests to a model view set and assert their
responses.
"""

0 comments on commit ab235bf

Please sign in to comment.