Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: model view set test case #62

Merged
merged 7 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 63 additions & 66 deletions codeforlife/tests/model_view_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,16 @@
from pyotp import TOTP
from rest_framework import status
from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer
from rest_framework.test import APIClient, APIRequestFactory, APITestCase
from rest_framework.viewsets import ModelViewSet

from ..serializers import ModelSerializer
from ..user.models import AuthFactor, User
from ..views import ModelViewSet

AnyModelViewSet = t.TypeVar("AnyModelViewSet", bound=ModelViewSet)
AnyModelSerializer = t.TypeVar("AnyModelSerializer", bound=ModelSerializer)
AnyModel = t.TypeVar("AnyModel", bound=Model)


class ModelViewSetClient(
APIClient,
t.Generic[AnyModelViewSet, AnyModelSerializer, AnyModel],
):
class ModelViewSetClient(APIClient, t.Generic[AnyModel]):
"""
An API client that helps make requests to a model view set and assert their
responses.
Expand All @@ -44,7 +39,7 @@ def __init__(self, enforce_csrf_checks: bool = False, **defaults):
**defaults,
)

_test_case: "ModelViewSetTestCase[AnyModelViewSet, AnyModelSerializer, AnyModel]"
_test_case: "ModelViewSetTestCase[AnyModel]"

@property
def _model_class(self):
Expand All @@ -53,19 +48,12 @@ def _model_class(self):
# pylint: disable-next=no-member
return self._test_case.get_model_class()

@property
def _model_serializer_class(self):
"""Shortcut to get model serializer class."""

# pylint: disable-next=no-member
return self._test_case.get_model_serializer_class()

@property
def _model_view_set_class(self):
"""Shortcut to get model view set class."""

# pylint: disable-next=no-member
return self._test_case.get_model_view_set_class()
return self._test_case.model_view_set_class

Data = t.Dict[str, t.Any]
StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]]
Expand All @@ -89,6 +77,9 @@ def assert_data_equals_model(
self,
data: Data,
model: AnyModel,
model_serializer_class: t.Optional[
t.Type[ModelSerializer[AnyModel]]
] = None,
contains_subset: bool = False,
):
# pylint: disable=line-too-long
Expand All @@ -115,7 +106,14 @@ def parse_data(data):
return data.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
return data

actual_data = parse_data(self._model_serializer_class(model).data)
if model_serializer_class is None:
model_serializer_class = (
# pylint: disable-next=no-member
self._test_case.model_serializer_class
or self._model_view_set_class().get_serializer_class()
)

actual_data = parse_data(model_serializer_class(model).data)

if contains_subset:
# pylint: disable-next=no-member
Expand Down Expand Up @@ -244,17 +242,23 @@ def retrieve(
self,
model: AnyModel,
status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK,
model_serializer_class: t.Optional[
t.Type[ModelSerializer[AnyModel]]
] = None,
**kwargs,
):
# pylint: disable=line-too-long
"""Retrieve a model.

Args:
model: The model to retrieve.
status_code_assertion: The expected status code.
model_serializer_class: The serializer used to serialize the model's data.

Returns:
The HTTP response.
"""
# pylint: enable=line-too-long

response: Response = self.get(
self.reverse("detail", model),
Expand All @@ -266,6 +270,7 @@ def retrieve(
self.assert_data_equals_model(
response.json(), # type: ignore[attr-defined]
model,
model_serializer_class,
)

return response
Expand All @@ -274,19 +279,25 @@ def list(
self,
models: t.Iterable[AnyModel],
status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK,
model_serializer_class: t.Optional[
t.Type[ModelSerializer[AnyModel]]
] = None,
filters: ListFilters = None,
**kwargs,
):
# pylint: disable=line-too-long
"""Retrieve a list of models.

Args:
models: The model list to retrieve.
status_code_assertion: The expected status code.
model_serializer_class: The serializer used to serialize the model's data.
filters: The filters to apply to the list.

Returns:
The HTTP response.
"""
# pylint: enable=line-too-long

assert self._model_class.objects.difference(
self._model_class.objects.filter(
Expand All @@ -301,8 +312,15 @@ def list(
)

if self.status_code_is_ok(response.status_code):
for data, model in zip(response.json()["data"], models): # type: ignore[attr-defined]
self.assert_data_equals_model(data, model)
for data, model in zip(
response.json()["data"], # type: ignore[attr-defined]
models,
):
self.assert_data_equals_model(
data,
model,
model_serializer_class,
)

return response

Expand All @@ -311,18 +329,24 @@ def partial_update(
model: AnyModel,
data: Data,
status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK,
model_serializer_class: t.Optional[
t.Type[ModelSerializer[AnyModel]]
] = None,
**kwargs,
):
# pylint: disable=line-too-long
"""Partially update a model.

Args:
model: The model to partially update.
data: The values for each field.
status_code_assertion: The expected status code.
model_serializer_class: The serializer used to serialize the model's data.

Returns:
The HTTP response.
"""
# pylint: enable=line-too-long

response: Response = self.patch(
self.reverse("detail", model),
Expand All @@ -336,6 +360,7 @@ def partial_update(
self.assert_data_equals_model(
response.json(), # type: ignore[attr-defined]
model,
model_serializer_class,
contains_subset=True,
)

Expand Down Expand Up @@ -367,7 +392,9 @@ def destroy(

if not anonymized and self.status_code_is_ok(response.status_code):
# pylint: disable-next=no-member
with self._test_case.assertRaises(model.DoesNotExist):
with self._test_case.assertRaises(
model.DoesNotExist # type: ignore[attr-defined]
):
model.refresh_from_db()

return response
Expand Down Expand Up @@ -397,20 +424,23 @@ def login(self, **credentials):

return user

def login_teacher(self, is_admin: bool, **credentials):
def login_teacher(self, is_admin: t.Optional[bool] = None, **credentials):
# pylint: disable=line-too-long
"""Log in a user and assert they are a teacher.

Args:
is_admin: Whether or not the teacher is an admin.
is_admin: Whether or not the teacher is an admin. Set none if a teacher can be either or.

Returns:
The teacher-user.
"""
# pylint: enable=line-too-long

user = self.login(**credentials)
assert user.teacher
assert user.teacher.school
assert is_admin == user.teacher.is_admin
if is_admin is not None:
assert is_admin == user.teacher.is_admin
return user

def login_student(self, **credentials):
Expand Down Expand Up @@ -438,56 +468,20 @@ def login_indy(self, **credentials):
return user


class ModelViewSetTestCase(
APITestCase,
t.Generic[AnyModelViewSet, AnyModelSerializer, AnyModel],
):
class ModelViewSetTestCase(APITestCase, t.Generic[AnyModel]):
"""Base for all model view set test cases."""

basename: str
client: ModelViewSetClient[ # type: ignore[assignment]
AnyModelViewSet,
AnyModelSerializer,
AnyModel,
]
model_view_set_class: t.Type[ModelViewSet[AnyModel]]
model_serializer_class: t.Optional[t.Type[ModelSerializer[AnyModel]]] = None
client: ModelViewSetClient[AnyModel]
client_class = ModelViewSetClient # type: ignore[assignment]

def _pre_setup(self):
super()._pre_setup()
super()._pre_setup() # type: ignore[misc]
# pylint: disable-next=protected-access
self.client._test_case = self

@classmethod
def _get_generic_args(
cls,
) -> t.Tuple[
t.Type[AnyModelViewSet],
t.Type[AnyModelSerializer],
t.Type[AnyModel],
]:
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0]) # type: ignore[attr-defined,return-value]

@classmethod
def get_model_view_set_class(cls):
"""Get the model view set's class.

Returns:
The model view set's class.
"""

return cls._get_generic_args()[0]

@classmethod
def get_model_serializer_class(cls):
"""Get the model serializer's class.

Returns:
The model serializer's class.
"""

return cls._get_generic_args()[1]

@classmethod
def get_model_class(cls):
"""Get the model view set's class.
Expand All @@ -496,7 +490,10 @@ def get_model_class(cls):
The model view set's class.
"""

return cls._get_generic_args()[2]
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]

def get_other_user(
self,
Expand Down
21 changes: 9 additions & 12 deletions codeforlife/user/migrations/0001_initial.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated by Django 3.2.20 on 2023-09-29 17:53
# Generated by Django 3.2.20 on 2024-01-24 18:42

import django.contrib.auth.models
import django.core.validators
Expand Down Expand Up @@ -43,6 +43,14 @@ class Migration(migrations.Migration):
'abstract': False,
},
),
migrations.CreateModel(
name='OtpBypassToken',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('token', models.CharField(max_length=8, validators=[django.core.validators.MinLengthValidator(8)])),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='otp_bypass_tokens', to='user.user')),
],
),
migrations.CreateModel(
name='AuthFactor',
fields=[
Expand All @@ -65,15 +73,4 @@ class Migration(migrations.Migration):
'unique_together': {('session', 'auth_factor')},
},
),
migrations.CreateModel(
name='OtpBypassToken',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('token', models.CharField(max_length=8, validators=[django.core.validators.MinLengthValidator(8)])),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='otp_bypass_tokens', to='user.user')),
],
options={
'unique_together': {('user', 'token')},
},
),
]
29 changes: 23 additions & 6 deletions codeforlife/user/models/otp_bypass_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
from django.core.exceptions import ValidationError
from django.core.validators import MinLengthValidator
from django.db import models
from django.utils.crypto import get_random_string

from . import user


class OtpBypassToken(models.Model):
length = 8
allowed_chars = "abcdefghijklmnopqrstuvwxyz"
max_count = 10
max_count_validation_error = ValidationError(
f"Exceeded max count of {max_count}"
Expand Down Expand Up @@ -51,13 +54,10 @@ def key(otp_bypass_token: OtpBypassToken):
)

token = models.CharField(
max_length=8,
validators=[MinLengthValidator(8)],
max_length=length,
validators=[MinLengthValidator(length)],
)

class Meta:
unique_together = ["user", "token"]

def save(self, *args, **kwargs):
if self.id is None:
if (
Expand All @@ -69,7 +69,24 @@ def save(self, *args, **kwargs):
return super().save(*args, **kwargs)

def check_token(self, token: str):
if check_password(token, self.token):
if check_password(token.lower(), self.token):
self.delete()
return True
return False

@classmethod
def generate_tokens(cls, count: int = max_count):
"""Generates a number of tokens.

Args:
count: The number of tokens to generate. Default to max.

Returns:
Raw tokens that are random and unique.
"""

tokens: t.Set[str] = set()
while len(tokens) < count:
tokens.add(get_random_string(cls.length, cls.allowed_chars))

return tokens
Loading
Loading