From 702e656968bb7ce308727b2f08fac7de76b97c28 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 24 Jan 2024 15:17:23 +0000 Subject: [PATCH] fix: model view set test case --- codeforlife/tests/model_view_set.py | 120 ++++++++++---------- codeforlife/user/tests/views/__init__.py | 4 + codeforlife/user/tests/views/test_klass.py | 6 +- codeforlife/user/tests/views/test_school.py | 6 +- codeforlife/user/tests/views/test_user.py | 4 +- codeforlife/user/views/__init__.py | 5 + codeforlife/user/views/klass.py | 21 +++- codeforlife/user/views/school.py | 19 +++- codeforlife/user/views/user.py | 14 ++- codeforlife/views/__init__.py | 6 + codeforlife/views/base.py | 35 ++++++ pyproject.toml | 24 +++- 12 files changed, 176 insertions(+), 88 deletions(-) create mode 100644 codeforlife/views/base.py diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index 6140585f..ff85fc45 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -17,21 +17,16 @@ from pyotp import TOTP from rest_framework import status from rest_framework.response import Response -from rest_framework.serializers import ModelSerializer from rest_framework.test import APIClient, APIRequestFactory, APITestCase -from rest_framework.viewsets import ModelViewSet +from ..serializers import ModelSerializer from ..user.models import AuthFactor, User +from ..views import ModelViewSet -AnyModelViewSet = t.TypeVar("AnyModelViewSet", bound=ModelViewSet) -AnyModelSerializer = t.TypeVar("AnyModelSerializer", bound=ModelSerializer) AnyModel = t.TypeVar("AnyModel", bound=Model) -class ModelViewSetClient( - APIClient, - t.Generic[AnyModelViewSet, AnyModelSerializer, AnyModel], -): +class ModelViewSetClient(APIClient, t.Generic[AnyModel]): """ An API client that helps make requests to a model view set and assert their responses. @@ -44,7 +39,7 @@ def __init__(self, enforce_csrf_checks: bool = False, **defaults): **defaults, ) - _test_case: "ModelViewSetTestCase[AnyModelViewSet, AnyModelSerializer, AnyModel]" + _test_case: "ModelViewSetTestCase[AnyModel]" @property def _model_class(self): @@ -53,19 +48,12 @@ def _model_class(self): # pylint: disable-next=no-member return self._test_case.get_model_class() - @property - def _model_serializer_class(self): - """Shortcut to get model serializer class.""" - - # pylint: disable-next=no-member - return self._test_case.get_model_serializer_class() - @property def _model_view_set_class(self): """Shortcut to get model view set class.""" # pylint: disable-next=no-member - return self._test_case.get_model_view_set_class() + return self._test_case.model_view_set_class Data = t.Dict[str, t.Any] StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]] @@ -89,6 +77,9 @@ def assert_data_equals_model( self, data: Data, model: AnyModel, + model_serializer_class: t.Optional[ + t.Type[ModelSerializer[AnyModel]] + ] = None, contains_subset: bool = False, ): # pylint: disable=line-too-long @@ -115,7 +106,14 @@ def parse_data(data): return data.strftime("%Y-%m-%dT%H:%M:%S.%fZ") return data - actual_data = parse_data(self._model_serializer_class(model).data) + if model_serializer_class is None: + model_serializer_class = ( + # pylint: disable-next=no-member + self._test_case.model_serializer_class + or self._model_view_set_class().get_serializer_class() + ) + + actual_data = parse_data(model_serializer_class(model).data) if contains_subset: # pylint: disable-next=no-member @@ -244,17 +242,23 @@ def retrieve( self, model: AnyModel, status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK, + model_serializer_class: t.Optional[ + t.Type[ModelSerializer[AnyModel]] + ] = None, **kwargs, ): + # pylint: disable=line-too-long """Retrieve a model. Args: model: The model to retrieve. status_code_assertion: The expected status code. + model_serializer_class: The serializer used to serialize the model's data. Returns: The HTTP response. """ + # pylint: enable=line-too-long response: Response = self.get( self.reverse("detail", model), @@ -266,6 +270,7 @@ def retrieve( self.assert_data_equals_model( response.json(), # type: ignore[attr-defined] model, + model_serializer_class, ) return response @@ -274,19 +279,25 @@ def list( self, models: t.Iterable[AnyModel], status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK, + model_serializer_class: t.Optional[ + t.Type[ModelSerializer[AnyModel]] + ] = None, filters: ListFilters = None, **kwargs, ): + # pylint: disable=line-too-long """Retrieve a list of models. Args: models: The model list to retrieve. status_code_assertion: The expected status code. + model_serializer_class: The serializer used to serialize the model's data. filters: The filters to apply to the list. Returns: The HTTP response. """ + # pylint: enable=line-too-long assert self._model_class.objects.difference( self._model_class.objects.filter( @@ -301,8 +312,15 @@ def list( ) if self.status_code_is_ok(response.status_code): - for data, model in zip(response.json()["data"], models): # type: ignore[attr-defined] - self.assert_data_equals_model(data, model) + for data, model in zip( + response.json()["data"], # type: ignore[attr-defined] + models, + ): + self.assert_data_equals_model( + data, + model, + model_serializer_class, + ) return response @@ -311,18 +329,24 @@ def partial_update( model: AnyModel, data: Data, status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK, + model_serializer_class: t.Optional[ + t.Type[ModelSerializer[AnyModel]] + ] = None, **kwargs, ): + # pylint: disable=line-too-long """Partially update a model. Args: model: The model to partially update. data: The values for each field. status_code_assertion: The expected status code. + model_serializer_class: The serializer used to serialize the model's data. Returns: The HTTP response. """ + # pylint: enable=line-too-long response: Response = self.patch( self.reverse("detail", model), @@ -336,6 +360,7 @@ def partial_update( self.assert_data_equals_model( response.json(), # type: ignore[attr-defined] model, + model_serializer_class, contains_subset=True, ) @@ -367,7 +392,9 @@ def destroy( if not anonymized and self.status_code_is_ok(response.status_code): # pylint: disable-next=no-member - with self._test_case.assertRaises(model.DoesNotExist): + with self._test_case.assertRaises( + model.DoesNotExist # type: ignore[attr-defined] + ): model.refresh_from_db() return response @@ -438,56 +465,20 @@ def login_indy(self, **credentials): return user -class ModelViewSetTestCase( - APITestCase, - t.Generic[AnyModelViewSet, AnyModelSerializer, AnyModel], -): +class ModelViewSetTestCase(APITestCase, t.Generic[AnyModel]): """Base for all model view set test cases.""" basename: str - client: ModelViewSetClient[ # type: ignore[assignment] - AnyModelViewSet, - AnyModelSerializer, - AnyModel, - ] + model_view_set_class: t.Type[ModelViewSet[AnyModel]] + model_serializer_class: t.Optional[t.Type[ModelSerializer[AnyModel]]] = None + client: ModelViewSetClient[AnyModel] client_class = ModelViewSetClient # type: ignore[assignment] def _pre_setup(self): - super()._pre_setup() + super()._pre_setup() # type: ignore[misc] # pylint: disable-next=protected-access self.client._test_case = self - @classmethod - def _get_generic_args( - cls, - ) -> t.Tuple[ - t.Type[AnyModelViewSet], - t.Type[AnyModelSerializer], - t.Type[AnyModel], - ]: - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0]) # type: ignore[attr-defined,return-value] - - @classmethod - def get_model_view_set_class(cls): - """Get the model view set's class. - - Returns: - The model view set's class. - """ - - return cls._get_generic_args()[0] - - @classmethod - def get_model_serializer_class(cls): - """Get the model serializer's class. - - Returns: - The model serializer's class. - """ - - return cls._get_generic_args()[1] - @classmethod def get_model_class(cls): """Get the model view set's class. @@ -496,7 +487,10 @@ def get_model_class(cls): The model view set's class. """ - return cls._get_generic_args()[2] + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 0 + ] def get_other_user( self, diff --git a/codeforlife/user/tests/views/__init__.py b/codeforlife/user/tests/views/__init__.py index e69de29b..fcf35baf 100644 --- a/codeforlife/user/tests/views/__init__.py +++ b/codeforlife/user/tests/views/__init__.py @@ -0,0 +1,4 @@ +""" +© Ocado Group +Created on 24/01/2024 at 13:52:24(+00:00). +""" diff --git a/codeforlife/user/tests/views/test_klass.py b/codeforlife/user/tests/views/test_klass.py index 9db22791..eda41b04 100644 --- a/codeforlife/user/tests/views/test_klass.py +++ b/codeforlife/user/tests/views/test_klass.py @@ -5,13 +5,10 @@ from ....tests import ModelViewSetTestCase from ...models import Class -from ...serializers import ClassSerializer from ...views import ClassViewSet -class TestClassViewSet( - ModelViewSetTestCase[ClassViewSet, ClassSerializer, Class] -): +class TestClassViewSet(ModelViewSetTestCase[Class]): """ Base naming convention: test_{action} @@ -21,6 +18,7 @@ class TestClassViewSet( """ basename = "class" + model_view_set_class = ClassViewSet def _login_student(self): return self.client.login_student( diff --git a/codeforlife/user/tests/views/test_school.py b/codeforlife/user/tests/views/test_school.py index ed6c792b..4b400521 100644 --- a/codeforlife/user/tests/views/test_school.py +++ b/codeforlife/user/tests/views/test_school.py @@ -8,13 +8,10 @@ from ....tests import ModelViewSetTestCase from ...models import Class, School, Student, Teacher, User, UserProfile -from ...serializers import SchoolSerializer from ...views import SchoolViewSet -class TestSchoolViewSet( - ModelViewSetTestCase[SchoolViewSet, SchoolSerializer, School] -): +class TestSchoolViewSet(ModelViewSetTestCase[School]): """ Base naming convention: test_{action} @@ -24,6 +21,7 @@ class TestSchoolViewSet( """ basename = "school" + model_view_set_class = SchoolViewSet # TODO: replace this setup with data fixtures. def setUp(self): diff --git a/codeforlife/user/tests/views/test_user.py b/codeforlife/user/tests/views/test_user.py index 7ca529ed..793dfc63 100644 --- a/codeforlife/user/tests/views/test_user.py +++ b/codeforlife/user/tests/views/test_user.py @@ -8,12 +8,11 @@ from ....tests import ModelViewSetTestCase from ...models import Class, School, Student, Teacher, User, UserProfile -from ...serializers import UserSerializer from ...views import UserViewSet # pylint: disable-next=too-many-ancestors,too-many-public-methods -class TestUserViewSet(ModelViewSetTestCase[UserViewSet, UserSerializer, User]): +class TestUserViewSet(ModelViewSetTestCase[User]): """ Base naming convention: test_{action} @@ -23,6 +22,7 @@ class TestUserViewSet(ModelViewSetTestCase[UserViewSet, UserSerializer, User]): """ basename = "user" + model_view_set_class = UserViewSet # TODO: replace this setup with data fixtures. def setUp(self): diff --git a/codeforlife/user/views/__init__.py b/codeforlife/user/views/__init__.py index ffcd7b4e..06f853cb 100644 --- a/codeforlife/user/views/__init__.py +++ b/codeforlife/user/views/__init__.py @@ -1,3 +1,8 @@ +""" +© Ocado Group +Created on 24/01/2024 at 13:52:02(+00:00). +""" + from .klass import ClassViewSet from .school import SchoolViewSet from .user import UserViewSet diff --git a/codeforlife/user/views/klass.py b/codeforlife/user/views/klass.py index bc863e88..c803729b 100644 --- a/codeforlife/user/views/klass.py +++ b/codeforlife/user/views/klass.py @@ -1,23 +1,32 @@ +""" +© Ocado Group +Created on 24/01/2024 at 13:47:53(+00:00). +""" + +import typing as t + from rest_framework.permissions import IsAuthenticated -from rest_framework.viewsets import ModelViewSet +from ...views import ModelViewSet from ..models import Class, User from ..permissions import IsSchoolMember from ..serializers import ClassSerializer -class ClassViewSet(ModelViewSet): +# pylint: disable-next=missing-class-docstring,too-few-public-methods +class ClassViewSet(ModelViewSet[Class]): http_method_names = ["get"] lookup_field = "access_code" serializer_class = ClassSerializer permission_classes = [IsAuthenticated, IsSchoolMember] + # pylint: disable-next=missing-function-docstring def get_queryset(self): - user: User = self.request.user + user = t.cast(User, self.request.user) if user.is_student: return Class.objects.filter(students=user.student) - elif user.teacher.is_admin: + if user.teacher.is_admin: # TODO: add school field to class object return Class.objects.filter(teacher__school=user.teacher.school) - else: - return Class.objects.filter(teacher=user.teacher) + + return Class.objects.filter(teacher=user.teacher) diff --git a/codeforlife/user/views/school.py b/codeforlife/user/views/school.py index 0c12d9cd..099e7bfa 100644 --- a/codeforlife/user/views/school.py +++ b/codeforlife/user/views/school.py @@ -1,22 +1,31 @@ +""" +© Ocado Group +Created on 24/01/2024 at 13:38:15(+00:00). +""" + +import typing as t + from rest_framework.permissions import IsAuthenticated -from rest_framework.viewsets import ModelViewSet +from ...views import ModelViewSet from ..models import School, User from ..permissions import IsSchoolMember from ..serializers import SchoolSerializer -class SchoolViewSet(ModelViewSet): +# pylint: disable-next=missing-class-docstring,too-few-public-methods +class SchoolViewSet(ModelViewSet[School]): http_method_names = ["get"] serializer_class = SchoolSerializer permission_classes = [IsAuthenticated, IsSchoolMember] + # pylint: disable-next=missing-function-docstring def get_queryset(self): - user: User = self.request.user + user = t.cast(User, self.request.user) if user.is_student: return School.objects.filter( # TODO: should be user.student.school_id id=user.student.class_field.teacher.school_id ) - else: - return School.objects.filter(id=user.teacher.school_id) + + return School.objects.filter(id=user.teacher.school_id) diff --git a/codeforlife/user/views/user.py b/codeforlife/user/views/user.py index 9efa496c..39c18e59 100644 --- a/codeforlife/user/views/user.py +++ b/codeforlife/user/views/user.py @@ -1,17 +1,25 @@ -from rest_framework.viewsets import ModelViewSet +""" +© Ocado Group +Created on 24/01/2024 at 13:12:05(+00:00). +""" +import typing as t + +from ...views import ModelViewSet from ..filters import UserFilterSet from ..models import User from ..serializers import UserSerializer -class UserViewSet(ModelViewSet): +# pylint: disable-next=missing-class-docstring,too-few-public-methods +class UserViewSet(ModelViewSet[User]): http_method_names = ["get"] serializer_class = UserSerializer filterset_class = UserFilterSet + # pylint: disable-next=missing-function-docstring def get_queryset(self): - user: User = self.request.user + user = t.cast(User, self.request.user) if user.is_student: if user.student.class_field is None: return User.objects.filter(id=user.id) diff --git a/codeforlife/views/__init__.py b/codeforlife/views/__init__.py index e69de29b..5b2d49eb 100644 --- a/codeforlife/views/__init__.py +++ b/codeforlife/views/__init__.py @@ -0,0 +1,6 @@ +""" +© Ocado Group +Created on 24/01/2024 at 13:07:38(+00:00). +""" + +from .base import ModelViewSet diff --git a/codeforlife/views/base.py b/codeforlife/views/base.py new file mode 100644 index 00000000..8e238753 --- /dev/null +++ b/codeforlife/views/base.py @@ -0,0 +1,35 @@ +""" +© Ocado Group +Created on 24/01/2024 at 13:08:23(+00:00). +""" + +import typing as t + +from django.db.models import Model +from rest_framework.viewsets import ModelViewSet as DrfModelViewSet + +AnyModel = t.TypeVar("AnyModel", bound=Model) + + +# pylint: disable-next=too-few-public-methods +class _ModelViewSet(t.Generic[AnyModel]): + pass + + +if t.TYPE_CHECKING: + # pylint: disable-next=too-few-public-methods + class ModelViewSet( + DrfModelViewSet[AnyModel], + _ModelViewSet[AnyModel], + t.Generic[AnyModel], + ): + """Base model view set for all model view sets.""" + +else: + # pylint: disable-next=missing-class-docstring,too-many-ancestors + class ModelViewSet( + DrfModelViewSet, + _ModelViewSet[AnyModel], + t.Generic[AnyModel], + ): + pass diff --git a/pyproject.toml b/pyproject.toml index 75a801c6..0d57c45b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,29 @@ upload_to_release = true [tool.black] line-length = 80 -extend-exclude = "^/codeforlife/user/migrations/" +extend-exclude = ".*/migrations/.*py" [tool.pytest.ini_options] env = ["DJANGO_SETTINGS_MODULE=manage"] + +[tool.mypy] +plugins = ["mypy_django_plugin.main", "mypy_drf_plugin.main"] +check_untyped_defs = true + +[tool.django-stubs] +django_settings_module = "manage" + +[tool.pylint.main] +init-hook = "import os; os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'manage')" + +[tool.pylint.format] +max-line-length = 80 + +[tool.pylint.MASTER] +ignore-paths = [".*/migrations/.*py"] +load-plugins = "pylint_django" + +[tool.isort] +profile = "black" +line_length = 80 +skip_glob = ["**/migrations/*.py"]