Skip to content

Commit

Permalink
fix: model view set test case
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Jan 24, 2024
1 parent 77d6413 commit 702e656
Show file tree
Hide file tree
Showing 12 changed files with 176 additions and 88 deletions.
120 changes: 57 additions & 63 deletions codeforlife/tests/model_view_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,16 @@
from pyotp import TOTP
from rest_framework import status
from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer
from rest_framework.test import APIClient, APIRequestFactory, APITestCase
from rest_framework.viewsets import ModelViewSet

from ..serializers import ModelSerializer
from ..user.models import AuthFactor, User
from ..views import ModelViewSet

AnyModelViewSet = t.TypeVar("AnyModelViewSet", bound=ModelViewSet)
AnyModelSerializer = t.TypeVar("AnyModelSerializer", bound=ModelSerializer)
AnyModel = t.TypeVar("AnyModel", bound=Model)


class ModelViewSetClient(
APIClient,
t.Generic[AnyModelViewSet, AnyModelSerializer, AnyModel],
):
class ModelViewSetClient(APIClient, t.Generic[AnyModel]):
"""
An API client that helps make requests to a model view set and assert their
responses.
Expand All @@ -44,7 +39,7 @@ def __init__(self, enforce_csrf_checks: bool = False, **defaults):
**defaults,
)

_test_case: "ModelViewSetTestCase[AnyModelViewSet, AnyModelSerializer, AnyModel]"
_test_case: "ModelViewSetTestCase[AnyModel]"

@property
def _model_class(self):
Expand All @@ -53,19 +48,12 @@ def _model_class(self):
# pylint: disable-next=no-member
return self._test_case.get_model_class()

@property
def _model_serializer_class(self):
"""Shortcut to get model serializer class."""

# pylint: disable-next=no-member
return self._test_case.get_model_serializer_class()

@property
def _model_view_set_class(self):
"""Shortcut to get model view set class."""

# pylint: disable-next=no-member
return self._test_case.get_model_view_set_class()
return self._test_case.model_view_set_class

Data = t.Dict[str, t.Any]
StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]]
Expand All @@ -89,6 +77,9 @@ def assert_data_equals_model(
self,
data: Data,
model: AnyModel,
model_serializer_class: t.Optional[
t.Type[ModelSerializer[AnyModel]]
] = None,
contains_subset: bool = False,
):
# pylint: disable=line-too-long
Expand All @@ -115,7 +106,14 @@ def parse_data(data):
return data.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
return data

actual_data = parse_data(self._model_serializer_class(model).data)
if model_serializer_class is None:
model_serializer_class = (
# pylint: disable-next=no-member
self._test_case.model_serializer_class
or self._model_view_set_class().get_serializer_class()
)

actual_data = parse_data(model_serializer_class(model).data)

if contains_subset:
# pylint: disable-next=no-member
Expand Down Expand Up @@ -244,17 +242,23 @@ def retrieve(
self,
model: AnyModel,
status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK,
model_serializer_class: t.Optional[
t.Type[ModelSerializer[AnyModel]]
] = None,
**kwargs,
):
# pylint: disable=line-too-long
"""Retrieve a model.
Args:
model: The model to retrieve.
status_code_assertion: The expected status code.
model_serializer_class: The serializer used to serialize the model's data.
Returns:
The HTTP response.
"""
# pylint: enable=line-too-long

response: Response = self.get(
self.reverse("detail", model),
Expand All @@ -266,6 +270,7 @@ def retrieve(
self.assert_data_equals_model(
response.json(), # type: ignore[attr-defined]
model,
model_serializer_class,
)

return response
Expand All @@ -274,19 +279,25 @@ def list(
self,
models: t.Iterable[AnyModel],
status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK,
model_serializer_class: t.Optional[
t.Type[ModelSerializer[AnyModel]]
] = None,
filters: ListFilters = None,
**kwargs,
):
# pylint: disable=line-too-long
"""Retrieve a list of models.
Args:
models: The model list to retrieve.
status_code_assertion: The expected status code.
model_serializer_class: The serializer used to serialize the model's data.
filters: The filters to apply to the list.
Returns:
The HTTP response.
"""
# pylint: enable=line-too-long

assert self._model_class.objects.difference(
self._model_class.objects.filter(
Expand All @@ -301,8 +312,15 @@ def list(
)

if self.status_code_is_ok(response.status_code):
for data, model in zip(response.json()["data"], models): # type: ignore[attr-defined]
self.assert_data_equals_model(data, model)
for data, model in zip(
response.json()["data"], # type: ignore[attr-defined]
models,
):
self.assert_data_equals_model(
data,
model,
model_serializer_class,
)

return response

Expand All @@ -311,18 +329,24 @@ def partial_update(
model: AnyModel,
data: Data,
status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK,
model_serializer_class: t.Optional[
t.Type[ModelSerializer[AnyModel]]
] = None,
**kwargs,
):
# pylint: disable=line-too-long
"""Partially update a model.
Args:
model: The model to partially update.
data: The values for each field.
status_code_assertion: The expected status code.
model_serializer_class: The serializer used to serialize the model's data.
Returns:
The HTTP response.
"""
# pylint: enable=line-too-long

response: Response = self.patch(
self.reverse("detail", model),
Expand All @@ -336,6 +360,7 @@ def partial_update(
self.assert_data_equals_model(
response.json(), # type: ignore[attr-defined]
model,
model_serializer_class,
contains_subset=True,
)

Expand Down Expand Up @@ -367,7 +392,9 @@ def destroy(

if not anonymized and self.status_code_is_ok(response.status_code):
# pylint: disable-next=no-member
with self._test_case.assertRaises(model.DoesNotExist):
with self._test_case.assertRaises(
model.DoesNotExist # type: ignore[attr-defined]
):
model.refresh_from_db()

return response
Expand Down Expand Up @@ -438,56 +465,20 @@ def login_indy(self, **credentials):
return user


class ModelViewSetTestCase(
APITestCase,
t.Generic[AnyModelViewSet, AnyModelSerializer, AnyModel],
):
class ModelViewSetTestCase(APITestCase, t.Generic[AnyModel]):
"""Base for all model view set test cases."""

basename: str
client: ModelViewSetClient[ # type: ignore[assignment]
AnyModelViewSet,
AnyModelSerializer,
AnyModel,
]
model_view_set_class: t.Type[ModelViewSet[AnyModel]]
model_serializer_class: t.Optional[t.Type[ModelSerializer[AnyModel]]] = None
client: ModelViewSetClient[AnyModel]
client_class = ModelViewSetClient # type: ignore[assignment]

def _pre_setup(self):
super()._pre_setup()
super()._pre_setup() # type: ignore[misc]
# pylint: disable-next=protected-access
self.client._test_case = self

@classmethod
def _get_generic_args(
cls,
) -> t.Tuple[
t.Type[AnyModelViewSet],
t.Type[AnyModelSerializer],
t.Type[AnyModel],
]:
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0]) # type: ignore[attr-defined,return-value]

@classmethod
def get_model_view_set_class(cls):
"""Get the model view set's class.
Returns:
The model view set's class.
"""

return cls._get_generic_args()[0]

@classmethod
def get_model_serializer_class(cls):
"""Get the model serializer's class.
Returns:
The model serializer's class.
"""

return cls._get_generic_args()[1]

@classmethod
def get_model_class(cls):
"""Get the model view set's class.
Expand All @@ -496,7 +487,10 @@ def get_model_class(cls):
The model view set's class.
"""

return cls._get_generic_args()[2]
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]

def get_other_user(
self,
Expand Down
4 changes: 4 additions & 0 deletions codeforlife/user/tests/views/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""
© Ocado Group
Created on 24/01/2024 at 13:52:24(+00:00).
"""
6 changes: 2 additions & 4 deletions codeforlife/user/tests/views/test_klass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@

from ....tests import ModelViewSetTestCase
from ...models import Class
from ...serializers import ClassSerializer
from ...views import ClassViewSet


class TestClassViewSet(
ModelViewSetTestCase[ClassViewSet, ClassSerializer, Class]
):
class TestClassViewSet(ModelViewSetTestCase[Class]):
"""
Base naming convention:
test_{action}
Expand All @@ -21,6 +18,7 @@ class TestClassViewSet(
"""

basename = "class"
model_view_set_class = ClassViewSet

def _login_student(self):
return self.client.login_student(
Expand Down
6 changes: 2 additions & 4 deletions codeforlife/user/tests/views/test_school.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,10 @@

from ....tests import ModelViewSetTestCase
from ...models import Class, School, Student, Teacher, User, UserProfile
from ...serializers import SchoolSerializer
from ...views import SchoolViewSet


class TestSchoolViewSet(
ModelViewSetTestCase[SchoolViewSet, SchoolSerializer, School]
):
class TestSchoolViewSet(ModelViewSetTestCase[School]):
"""
Base naming convention:
test_{action}
Expand All @@ -24,6 +21,7 @@ class TestSchoolViewSet(
"""

basename = "school"
model_view_set_class = SchoolViewSet

# TODO: replace this setup with data fixtures.
def setUp(self):
Expand Down
4 changes: 2 additions & 2 deletions codeforlife/user/tests/views/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@

from ....tests import ModelViewSetTestCase
from ...models import Class, School, Student, Teacher, User, UserProfile
from ...serializers import UserSerializer
from ...views import UserViewSet


# pylint: disable-next=too-many-ancestors,too-many-public-methods
class TestUserViewSet(ModelViewSetTestCase[UserViewSet, UserSerializer, User]):
class TestUserViewSet(ModelViewSetTestCase[User]):
"""
Base naming convention:
test_{action}
Expand All @@ -23,6 +22,7 @@ class TestUserViewSet(ModelViewSetTestCase[UserViewSet, UserSerializer, User]):
"""

basename = "user"
model_view_set_class = UserViewSet

# TODO: replace this setup with data fixtures.
def setUp(self):
Expand Down
5 changes: 5 additions & 0 deletions codeforlife/user/views/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""
© Ocado Group
Created on 24/01/2024 at 13:52:02(+00:00).
"""

from .klass import ClassViewSet
from .school import SchoolViewSet
from .user import UserViewSet
21 changes: 15 additions & 6 deletions codeforlife/user/views/klass.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
"""
© Ocado Group
Created on 24/01/2024 at 13:47:53(+00:00).
"""

import typing as t

from rest_framework.permissions import IsAuthenticated
from rest_framework.viewsets import ModelViewSet

from ...views import ModelViewSet
from ..models import Class, User
from ..permissions import IsSchoolMember
from ..serializers import ClassSerializer


class ClassViewSet(ModelViewSet):
# pylint: disable-next=missing-class-docstring,too-few-public-methods
class ClassViewSet(ModelViewSet[Class]):
http_method_names = ["get"]
lookup_field = "access_code"
serializer_class = ClassSerializer
permission_classes = [IsAuthenticated, IsSchoolMember]

# pylint: disable-next=missing-function-docstring
def get_queryset(self):
user: User = self.request.user
user = t.cast(User, self.request.user)
if user.is_student:
return Class.objects.filter(students=user.student)
elif user.teacher.is_admin:
if user.teacher.is_admin:
# TODO: add school field to class object
return Class.objects.filter(teacher__school=user.teacher.school)
else:
return Class.objects.filter(teacher=user.teacher)

return Class.objects.filter(teacher=user.teacher)
Loading

0 comments on commit 702e656

Please sign in to comment.