Skip to content

Commit

Permalink
quick save
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Jan 19, 2024
1 parent 809269a commit 9c0ae8d
Show file tree
Hide file tree
Showing 2 changed files with 368 additions and 87 deletions.
328 changes: 328 additions & 0 deletions codeforlife/tests/model_view_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,328 @@
"""
© Ocado Group
Created on 19/01/2024 at 17:06:45(+00:00).
Base test case for all model view sets.
"""

import typing as t
from datetime import datetime
from unittest.mock import patch

from django.db.models import Model
from django.urls import reverse
from django.utils import timezone
from django.utils.http import urlencode
from pyotp import TOTP
from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer
from rest_framework.test import APIClient, APITestCase
from rest_framework.viewsets import ModelViewSet

from ..user.models import AuthFactor, User

AnyModelViewSet = t.TypeVar("AnyModelViewSet", bound=ModelViewSet)
AnyModelSerializer = t.TypeVar("AnyModelSerializer", bound=ModelSerializer)
AnyModel = t.TypeVar("AnyModel", bound=Model)


class ModelViewSetClient(
APIClient,
t.Generic[AnyModelViewSet, AnyModelSerializer, AnyModel],
):
"""
An API client that helps make requests to a model view set and assert their
responses.
"""

basename: str
model_class: t.Type[AnyModel]
model_serializer_class: t.Type[AnyModelSerializer]
model_view_set_class: t.Type[AnyModelViewSet]

StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]]
ListFilters = t.Optional[t.Dict[str, str]]

@staticmethod
def status_code_is_ok(status_code: int):
"""Check if the status code is greater than or equal to 200 and less
than 300.
Args:
status_code: The status code to check.
Returns:
A flag designating if the status code is OK.
"""

return 200 <= status_code < 300

def assert_data_equals_model(
self,
data: t.Dict[str, t.Any],
model: AnyModel,
):
"""Check if the data equals the current state of the model instance.
Args:
data: The data to check.
model: The model instance.
model_serializer_class: The serializer used to serialize the model's data.
Returns:
A flag designating if the data equals the current state of the model
instance.
"""

def parse_data(data):
if isinstance(data, list):
return [parse_data(value) for value in data]
if isinstance(data, dict):
return {key: parse_data(value) for key, value in data.items()}
if isinstance(data, datetime):
return data.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
return data

assert data == parse_data(
self.model_serializer_class(model).data
), "Data does not equal serialized model."

# pylint: disable-next=too-many-arguments
def generic(
self,
method,
path,
data="",
content_type="application/octet-stream",
secure=False,
status_code_assertion: StatusCodeAssertion = None,
**extra,
):
response = t.cast(
Response,
super().generic(
method,
path,
data,
content_type,
secure,
**extra,
),
)

# Use a custom kwarg to handle the common case of checking the
# response's status code.
if status_code_assertion is None:
status_code_assertion = self.status_code_is_ok
elif isinstance(status_code_assertion, int):
expected_status_code = status_code_assertion
status_code_assertion = (
# pylint: disable-next=unnecessary-lambda-assignment
lambda status_code: status_code
== expected_status_code
)

# pylint: disable-next=no-member
status_code = response.status_code
assert status_code_assertion(
status_code
), f"Unexpected status code: {status_code}."

return response

def retrieve(
self,
model: AnyModel,
status_code_assertion: StatusCodeAssertion = None,
**kwargs,
):
"""Retrieve a model from the view set.
Args:
model: The model to retrieve.
status_code_assertion: The expected status code.
Returns:
The HTTP response.
"""

response: Response = self.get(
reverse(
f"{self.basename}-detail",
kwargs={
self.model_view_set_class.lookup_field: getattr(
model, self.model_view_set_class.lookup_field
)
},
),
status_code_assertion=status_code_assertion,
**kwargs,
)

if self.status_code_is_ok(response.status_code):
self.assert_data_equals_model(
response.json(), # type: ignore[attr-defined]
model,
)

return response

def list(
self,
models: t.Iterable[AnyModel],
status_code_assertion: StatusCodeAssertion = None,
filters: ListFilters = None,
**kwargs,
):
"""Retrieve a list of models from the view set.
Args:
models: The model list to retrieve.
status_code_assertion: The expected status code.
filters: The filters to apply to the list.
Returns:
The HTTP response.
"""

assert self.model_class.objects.difference(
self.model_class.objects.filter(
pk__in=[model.pk for model in models]
)
).exists(), "List must exclude some models for a valid test."

response: Response = self.get(
f"{reverse(f'{self.basename}-list')}?{urlencode(filters or {})}",
status_code_assertion=status_code_assertion,
**kwargs,
)

if self.status_code_is_ok(response.status_code):
for data, model in zip(response.json()["data"], models): # type: ignore[attr-defined]
self.assert_data_equals_model(data, model)

return response

def login(self, **credentials):
assert super().login(
**credentials
), f"Failed to login with credentials: {credentials}."

user = User.objects.get(session=self.session.session_key)

if user.session.session_auth_factors.filter(
auth_factor__type=AuthFactor.Type.OTP
).exists():
now = timezone.now()
otp = TOTP(user.otp_secret).at(now)
with patch.object(timezone, "now", return_value=now):
assert super().login(
otp=otp
), f'Failed to login with OTP "{otp}" at {now}.'

assert user.is_authenticated, "Failed to authenticate user."

return user

def login_teacher(self, is_admin: bool, **credentials):
"""Log in a user and assert they are a teacher.
Args:
is_admin: Whether or not the teacher is an admin.
Returns:
The teacher-user.
"""

user = self.login(**credentials)
assert user.teacher
assert user.teacher.school
assert is_admin == user.teacher.is_admin
return user

def login_student(self, **credentials):
"""Log in a user and assert they are a student.
Returns:
The student-user.
"""

user = self.login(**credentials)
assert user.student
assert user.student.class_field.teacher.school
return user

def login_indy(self, **credentials):
"""Log in an independent and assert they are a student.
Returns:
The independent-user.
"""

user = self.login(**credentials)
assert user.student
assert not user.student.class_field
return user


class ModelViewSetTestCase(
APITestCase,
t.Generic[AnyModelViewSet, AnyModelSerializer, AnyModel],
):
"""Base for all model view set test cases."""

basename: str
client: ModelViewSetClient[ # type: ignore[assignment]
AnyModelViewSet,
AnyModelSerializer,
AnyModel,
]
client_class = ModelViewSetClient # type: ignore[assignment]

def _pre_setup(self):
super()._pre_setup()
self.client.basename = self.basename
self.client.model_view_set_class = self.get_model_view_set_class()
self.client.model_serializer_class = self.get_model_serializer_class()
self.client.model_class = self.get_model_class()

@classmethod
def _get_generic_args(
cls,
) -> t.Tuple[
t.Type[AnyModelViewSet],
t.Type[AnyModelSerializer],
t.Type[AnyModel],
]:
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0]) # type: ignore[attr-defined,return-value]

@classmethod
def get_model_view_set_class(cls):
"""Get the model view set's class.
Returns:
The model view set's class.
"""

return cls._get_generic_args()[0]

@classmethod
def get_model_serializer_class(cls):
"""Get the model serializer's class.
Returns:
The model serializer's class.
"""

return cls._get_generic_args()[1]

@classmethod
def get_model_class(cls):
"""Get the model view set's class.
Returns:
The model view set's class.
"""

return cls._get_generic_args()[2]
Loading

0 comments on commit 9c0ae8d

Please sign in to comment.