diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py new file mode 100644 index 00000000..a1b070c6 --- /dev/null +++ b/codeforlife/tests/model_view_set.py @@ -0,0 +1,328 @@ +""" +© Ocado Group +Created on 19/01/2024 at 17:06:45(+00:00). + +Base test case for all model view sets. +""" + +import typing as t +from datetime import datetime +from unittest.mock import patch + +from django.db.models import Model +from django.urls import reverse +from django.utils import timezone +from django.utils.http import urlencode +from pyotp import TOTP +from rest_framework.response import Response +from rest_framework.serializers import ModelSerializer +from rest_framework.test import APIClient, APITestCase +from rest_framework.viewsets import ModelViewSet + +from ..user.models import AuthFactor, User + +AnyModelViewSet = t.TypeVar("AnyModelViewSet", bound=ModelViewSet) +AnyModelSerializer = t.TypeVar("AnyModelSerializer", bound=ModelSerializer) +AnyModel = t.TypeVar("AnyModel", bound=Model) + + +class ModelViewSetClient( + APIClient, + t.Generic[AnyModelViewSet, AnyModelSerializer, AnyModel], +): + """ + An API client that helps make requests to a model view set and assert their + responses. + """ + + basename: str + model_class: t.Type[AnyModel] + model_serializer_class: t.Type[AnyModelSerializer] + model_view_set_class: t.Type[AnyModelViewSet] + + StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]] + ListFilters = t.Optional[t.Dict[str, str]] + + @staticmethod + def status_code_is_ok(status_code: int): + """Check if the status code is greater than or equal to 200 and less + than 300. + + Args: + status_code: The status code to check. + + Returns: + A flag designating if the status code is OK. + """ + + return 200 <= status_code < 300 + + def assert_data_equals_model( + self, + data: t.Dict[str, t.Any], + model: AnyModel, + ): + """Check if the data equals the current state of the model instance. + + Args: + data: The data to check. + model: The model instance. + model_serializer_class: The serializer used to serialize the model's data. + + Returns: + A flag designating if the data equals the current state of the model + instance. + """ + + def parse_data(data): + if isinstance(data, list): + return [parse_data(value) for value in data] + if isinstance(data, dict): + return {key: parse_data(value) for key, value in data.items()} + if isinstance(data, datetime): + return data.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + return data + + assert data == parse_data( + self.model_serializer_class(model).data + ), "Data does not equal serialized model." + + # pylint: disable-next=too-many-arguments + def generic( + self, + method, + path, + data="", + content_type="application/octet-stream", + secure=False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + response = t.cast( + Response, + super().generic( + method, + path, + data, + content_type, + secure, + **extra, + ), + ) + + # Use a custom kwarg to handle the common case of checking the + # response's status code. + if status_code_assertion is None: + status_code_assertion = self.status_code_is_ok + elif isinstance(status_code_assertion, int): + expected_status_code = status_code_assertion + status_code_assertion = ( + # pylint: disable-next=unnecessary-lambda-assignment + lambda status_code: status_code + == expected_status_code + ) + + # pylint: disable-next=no-member + status_code = response.status_code + assert status_code_assertion( + status_code + ), f"Unexpected status code: {status_code}." + + return response + + def retrieve( + self, + model: AnyModel, + status_code_assertion: StatusCodeAssertion = None, + **kwargs, + ): + """Retrieve a model from the view set. + + Args: + model: The model to retrieve. + status_code_assertion: The expected status code. + + Returns: + The HTTP response. + """ + + response: Response = self.get( + reverse( + f"{self.basename}-detail", + kwargs={ + self.model_view_set_class.lookup_field: getattr( + model, self.model_view_set_class.lookup_field + ) + }, + ), + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if self.status_code_is_ok(response.status_code): + self.assert_data_equals_model( + response.json(), # type: ignore[attr-defined] + model, + ) + + return response + + def list( + self, + models: t.Iterable[AnyModel], + status_code_assertion: StatusCodeAssertion = None, + filters: ListFilters = None, + **kwargs, + ): + """Retrieve a list of models from the view set. + + Args: + models: The model list to retrieve. + status_code_assertion: The expected status code. + filters: The filters to apply to the list. + + Returns: + The HTTP response. + """ + + assert self.model_class.objects.difference( + self.model_class.objects.filter( + pk__in=[model.pk for model in models] + ) + ).exists(), "List must exclude some models for a valid test." + + response: Response = self.get( + f"{reverse(f'{self.basename}-list')}?{urlencode(filters or {})}", + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if self.status_code_is_ok(response.status_code): + for data, model in zip(response.json()["data"], models): # type: ignore[attr-defined] + self.assert_data_equals_model(data, model) + + return response + + def login(self, **credentials): + assert super().login( + **credentials + ), f"Failed to login with credentials: {credentials}." + + user = User.objects.get(session=self.session.session_key) + + if user.session.session_auth_factors.filter( + auth_factor__type=AuthFactor.Type.OTP + ).exists(): + now = timezone.now() + otp = TOTP(user.otp_secret).at(now) + with patch.object(timezone, "now", return_value=now): + assert super().login( + otp=otp + ), f'Failed to login with OTP "{otp}" at {now}.' + + assert user.is_authenticated, "Failed to authenticate user." + + return user + + def login_teacher(self, is_admin: bool, **credentials): + """Log in a user and assert they are a teacher. + + Args: + is_admin: Whether or not the teacher is an admin. + + Returns: + The teacher-user. + """ + + user = self.login(**credentials) + assert user.teacher + assert user.teacher.school + assert is_admin == user.teacher.is_admin + return user + + def login_student(self, **credentials): + """Log in a user and assert they are a student. + + Returns: + The student-user. + """ + + user = self.login(**credentials) + assert user.student + assert user.student.class_field.teacher.school + return user + + def login_indy(self, **credentials): + """Log in an independent and assert they are a student. + + Returns: + The independent-user. + """ + + user = self.login(**credentials) + assert user.student + assert not user.student.class_field + return user + + +class ModelViewSetTestCase( + APITestCase, + t.Generic[AnyModelViewSet, AnyModelSerializer, AnyModel], +): + """Base for all model view set test cases.""" + + basename: str + client: ModelViewSetClient[ # type: ignore[assignment] + AnyModelViewSet, + AnyModelSerializer, + AnyModel, + ] + client_class = ModelViewSetClient # type: ignore[assignment] + + def _pre_setup(self): + super()._pre_setup() + self.client.basename = self.basename + self.client.model_view_set_class = self.get_model_view_set_class() + self.client.model_serializer_class = self.get_model_serializer_class() + self.client.model_class = self.get_model_class() + + @classmethod + def _get_generic_args( + cls, + ) -> t.Tuple[ + t.Type[AnyModelViewSet], + t.Type[AnyModelSerializer], + t.Type[AnyModel], + ]: + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0]) # type: ignore[attr-defined,return-value] + + @classmethod + def get_model_view_set_class(cls): + """Get the model view set's class. + + Returns: + The model view set's class. + """ + + return cls._get_generic_args()[0] + + @classmethod + def get_model_serializer_class(cls): + """Get the model serializer's class. + + Returns: + The model serializer's class. + """ + + return cls._get_generic_args()[1] + + @classmethod + def get_model_class(cls): + """Get the model view set's class. + + Returns: + The model view set's class. + """ + + return cls._get_generic_args()[2] diff --git a/codeforlife/user/tests/views/test_user.py b/codeforlife/user/tests/views/test_user.py index 4cb9a1c9..f59ff0e9 100644 --- a/codeforlife/user/tests/views/test_user.py +++ b/codeforlife/user/tests/views/test_user.py @@ -1,15 +1,19 @@ -import typing as t +""" +© Ocado Group +Created on 19/01/2024 at 17:15:56(+00:00). +""" from rest_framework import status from rest_framework.permissions import IsAuthenticated -from ....tests import APIClient, APITestCase +from ....tests import ModelViewSetTestCase from ...models import Class, School, Student, Teacher, User, UserProfile from ...serializers import UserSerializer from ...views import UserViewSet -class TestUserViewSet(APITestCase): +# pylint: disable-next=too-many-ancestors,too-many-public-methods +class TestUserViewSet(ModelViewSetTestCase[UserViewSet, UserSerializer, User]): """ Base naming convention: test_{action} @@ -18,6 +22,8 @@ class TestUserViewSet(APITestCase): https://www.django-rest-framework.org/api-guide/viewsets/#viewset-actions """ + basename = "user" + # TODO: replace this setup with data fixtures. def setUp(self): school = School.objects.create( @@ -81,12 +87,13 @@ def _login_student(self): password="Password1", ) - def _login_indy_student(self): - return self.client.login_indy_student( + def _login_indy(self): + return self.client.login_indy( email="indianajones@codeforlife.com", password="Password1", ) + # pylint: disable-next=pointless-string-statement """ Retrieve naming convention: test_retrieve__{user_type}__{other_user_type}__{same_school}__{same_class} @@ -111,18 +118,6 @@ def _login_indy_student(self): - not_same_class: The other user is not from the same class. """ - def _retrieve_user( - self, - user: User, - status_code_assertion: APIClient.StatusCodeAssertion = None, - ): - return self.client.retrieve( - "user", - user, - UserSerializer, - status_code_assertion, - ) - def test_retrieve__teacher__self(self): """ Teacher can retrieve their own user data. @@ -130,7 +125,7 @@ def test_retrieve__teacher__self(self): user = self._login_teacher() - self._retrieve_user(user) + self.client.retrieve(user) def test_retrieve__student__self(self): """ @@ -139,16 +134,16 @@ def test_retrieve__student__self(self): user = self._login_student() - self._retrieve_user(user) + self.client.retrieve(user) def test_retrieve__indy_student__self(self): """ Independent student can retrieve their own user data. """ - user = self._login_indy_student() + user = self._login_indy() - self._retrieve_user(user) + self.client.retrieve(user) def test_retrieve__teacher__teacher__same_school(self): """ @@ -166,7 +161,7 @@ def test_retrieve__teacher__teacher__same_school(self): same_school=True, ) - self._retrieve_user(other_user) + self.client.retrieve(other_user) def test_retrieve__teacher__student__same_school__same_class(self): """ @@ -186,7 +181,7 @@ def test_retrieve__teacher__student__same_school__same_class(self): same_class=True, ) - self._retrieve_user(other_user) + self.client.retrieve(other_user) def test_retrieve__teacher__student__same_school__not_same_class(self): """ @@ -206,10 +201,7 @@ def test_retrieve__teacher__student__same_school__not_same_class(self): same_class=False, ) - self._retrieve_user( - other_user, - status_code_assertion=status.HTTP_404_NOT_FOUND, - ) + self.client.retrieve(other_user, status.HTTP_404_NOT_FOUND) def test_retrieve__admin_teacher__student__same_school__same_class(self): """ @@ -229,7 +221,7 @@ def test_retrieve__admin_teacher__student__same_school__same_class(self): same_class=True, ) - self._retrieve_user(other_user) + self.client.retrieve(other_user) def test_retrieve__admin_teacher__student__same_school__not_same_class( self, @@ -251,7 +243,7 @@ def test_retrieve__admin_teacher__student__same_school__not_same_class( same_class=False, ) - self._retrieve_user(other_user) + self.client.retrieve(other_user) def test_retrieve__student__teacher__same_school__same_class(self): """ @@ -271,10 +263,7 @@ def test_retrieve__student__teacher__same_school__same_class(self): same_class=True, ) - self._retrieve_user( - other_user, - status_code_assertion=status.HTTP_404_NOT_FOUND, - ) + self.client.retrieve(other_user, status.HTTP_404_NOT_FOUND) def test_retrieve__student__teacher__same_school__not_same_class(self): """ @@ -294,10 +283,7 @@ def test_retrieve__student__teacher__same_school__not_same_class(self): same_class=False, ) - self._retrieve_user( - other_user, - status_code_assertion=status.HTTP_404_NOT_FOUND, - ) + self.client.retrieve(other_user, status.HTTP_404_NOT_FOUND) def test_retrieve__student__student__same_school__same_class(self): """ @@ -317,7 +303,7 @@ def test_retrieve__student__student__same_school__same_class(self): same_class=True, ) - self._retrieve_user(other_user) + self.client.retrieve(other_user) def test_retrieve__student__student__same_school__not_same_class(self): """ @@ -339,10 +325,7 @@ def test_retrieve__student__student__same_school__not_same_class(self): same_class=False, ) - self._retrieve_user( - other_user, - status_code_assertion=status.HTTP_404_NOT_FOUND, - ) + self.client.retrieve(other_user, status.HTTP_404_NOT_FOUND) def test_retrieve__teacher__teacher__not_same_school(self): """ @@ -360,10 +343,7 @@ def test_retrieve__teacher__teacher__not_same_school(self): same_school=False, ) - self._retrieve_user( - other_user, - status_code_assertion=status.HTTP_404_NOT_FOUND, - ) + self.client.retrieve(other_user, status.HTTP_404_NOT_FOUND) def test_retrieve__teacher__student__not_same_school(self): """ @@ -381,10 +361,7 @@ def test_retrieve__teacher__student__not_same_school(self): same_school=False, ) - self._retrieve_user( - other_user, - status_code_assertion=status.HTTP_404_NOT_FOUND, - ) + self.client.retrieve(other_user, status.HTTP_404_NOT_FOUND) def test_retrieve__student__teacher__not_same_school(self): """ @@ -402,10 +379,7 @@ def test_retrieve__student__teacher__not_same_school(self): same_school=False, ) - self._retrieve_user( - other_user, - status_code_assertion=status.HTTP_404_NOT_FOUND, - ) + self.client.retrieve(other_user, status.HTTP_404_NOT_FOUND) def test_retrieve__student__student__not_same_school(self): """ @@ -423,17 +397,14 @@ def test_retrieve__student__student__not_same_school(self): same_school=False, ) - self._retrieve_user( - other_user, - status_code_assertion=status.HTTP_404_NOT_FOUND, - ) + self.client.retrieve(other_user, status.HTTP_404_NOT_FOUND) def test_retrieve__indy_student__teacher(self): """ Independent student cannot retrieve a teacher. """ - user = self._login_indy_student() + user = self._login_indy() other_user = self.get_other_school_user( user, @@ -441,17 +412,14 @@ def test_retrieve__indy_student__teacher(self): is_teacher=True, ) - self._retrieve_user( - other_user, - status_code_assertion=status.HTTP_404_NOT_FOUND, - ) + self.client.retrieve(other_user, status.HTTP_404_NOT_FOUND) def test_retrieve__indy_student__student(self): """ Independent student cannot retrieve a student. """ - user = self._login_indy_student() + user = self._login_indy() other_user = self.get_other_school_user( user, @@ -461,11 +429,9 @@ def test_retrieve__indy_student__student(self): is_teacher=False, ) - self._retrieve_user( - other_user, - status_code_assertion=status.HTTP_404_NOT_FOUND, - ) + self.client.retrieve(other_user, status.HTTP_404_NOT_FOUND) + # pylint: disable-next=pointless-string-statement """ List naming convention: test_list__{user_type}__{filters} @@ -478,20 +444,6 @@ def test_retrieve__indy_student__student(self): filters: Any search params used to dynamically filter the list. """ - def _list_users( - self, - users: t.Iterable[User], - status_code_assertion: APIClient.StatusCodeAssertion = None, - filters: APIClient.ListFilters = None, - ): - return self.client.list( - "user", - users, - UserSerializer, - status_code_assertion, - filters, - ) - def test_list__teacher(self): """ Teacher can list all the users in the same school. @@ -499,7 +451,7 @@ def test_list__teacher(self): user = self._login_teacher() - self._list_users( + self.client.list( User.objects.filter(new_teacher__school=user.teacher.school) | User.objects.filter( new_student__class_field__teacher__school=user.teacher.school, @@ -517,7 +469,7 @@ def test_list__teacher__students_in_class(self): klass = user.teacher.class_teacher.first() assert klass - self._list_users( + self.client.list( User.objects.filter(new_student__class_field=klass), filters={"students_in_class": klass.id}, ) @@ -529,17 +481,18 @@ def test_list__student(self): user = self._login_student() - self._list_users([user]) + self.client.list([user]) def test_list__indy_student(self): """ Independent student can list only themself. """ - user = self._login_indy_student() + user = self._login_indy() - self._list_users([user]) + self.client.list([user]) + # pylint: disable-next=pointless-string-statement """ General tests that apply to all actions. """