Skip to content

Commit

Permalink
fix: model view set test case (#62)
Browse files Browse the repository at this point in the history
* fix: model view set test case

* fix otp bypass token model

* fix test

* add type hint for serializer class

* make is_admin optional

* add permissions

* wxyz
  • Loading branch information
SKairinos authored Jan 25, 2024
1 parent 77d6413 commit 8288825
Show file tree
Hide file tree
Showing 24 changed files with 484 additions and 192 deletions.
129 changes: 63 additions & 66 deletions codeforlife/tests/model_view_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,16 @@
from pyotp import TOTP
from rest_framework import status
from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer
from rest_framework.test import APIClient, APIRequestFactory, APITestCase
from rest_framework.viewsets import ModelViewSet

from ..serializers import ModelSerializer
from ..user.models import AuthFactor, User
from ..views import ModelViewSet

AnyModelViewSet = t.TypeVar("AnyModelViewSet", bound=ModelViewSet)
AnyModelSerializer = t.TypeVar("AnyModelSerializer", bound=ModelSerializer)
AnyModel = t.TypeVar("AnyModel", bound=Model)


class ModelViewSetClient(
APIClient,
t.Generic[AnyModelViewSet, AnyModelSerializer, AnyModel],
):
class ModelViewSetClient(APIClient, t.Generic[AnyModel]):
"""
An API client that helps make requests to a model view set and assert their
responses.
Expand All @@ -44,7 +39,7 @@ def __init__(self, enforce_csrf_checks: bool = False, **defaults):
**defaults,
)

_test_case: "ModelViewSetTestCase[AnyModelViewSet, AnyModelSerializer, AnyModel]"
_test_case: "ModelViewSetTestCase[AnyModel]"

@property
def _model_class(self):
Expand All @@ -53,19 +48,12 @@ def _model_class(self):
# pylint: disable-next=no-member
return self._test_case.get_model_class()

@property
def _model_serializer_class(self):
"""Shortcut to get model serializer class."""

# pylint: disable-next=no-member
return self._test_case.get_model_serializer_class()

@property
def _model_view_set_class(self):
"""Shortcut to get model view set class."""

# pylint: disable-next=no-member
return self._test_case.get_model_view_set_class()
return self._test_case.model_view_set_class

Data = t.Dict[str, t.Any]
StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]]
Expand All @@ -89,6 +77,9 @@ def assert_data_equals_model(
self,
data: Data,
model: AnyModel,
model_serializer_class: t.Optional[
t.Type[ModelSerializer[AnyModel]]
] = None,
contains_subset: bool = False,
):
# pylint: disable=line-too-long
Expand All @@ -115,7 +106,14 @@ def parse_data(data):
return data.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
return data

actual_data = parse_data(self._model_serializer_class(model).data)
if model_serializer_class is None:
model_serializer_class = (
# pylint: disable-next=no-member
self._test_case.model_serializer_class
or self._model_view_set_class().get_serializer_class()
)

actual_data = parse_data(model_serializer_class(model).data)

if contains_subset:
# pylint: disable-next=no-member
Expand Down Expand Up @@ -244,17 +242,23 @@ def retrieve(
self,
model: AnyModel,
status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK,
model_serializer_class: t.Optional[
t.Type[ModelSerializer[AnyModel]]
] = None,
**kwargs,
):
# pylint: disable=line-too-long
"""Retrieve a model.
Args:
model: The model to retrieve.
status_code_assertion: The expected status code.
model_serializer_class: The serializer used to serialize the model's data.
Returns:
The HTTP response.
"""
# pylint: enable=line-too-long

response: Response = self.get(
self.reverse("detail", model),
Expand All @@ -266,6 +270,7 @@ def retrieve(
self.assert_data_equals_model(
response.json(), # type: ignore[attr-defined]
model,
model_serializer_class,
)

return response
Expand All @@ -274,19 +279,25 @@ def list(
self,
models: t.Iterable[AnyModel],
status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK,
model_serializer_class: t.Optional[
t.Type[ModelSerializer[AnyModel]]
] = None,
filters: ListFilters = None,
**kwargs,
):
# pylint: disable=line-too-long
"""Retrieve a list of models.
Args:
models: The model list to retrieve.
status_code_assertion: The expected status code.
model_serializer_class: The serializer used to serialize the model's data.
filters: The filters to apply to the list.
Returns:
The HTTP response.
"""
# pylint: enable=line-too-long

assert self._model_class.objects.difference(
self._model_class.objects.filter(
Expand All @@ -301,8 +312,15 @@ def list(
)

if self.status_code_is_ok(response.status_code):
for data, model in zip(response.json()["data"], models): # type: ignore[attr-defined]
self.assert_data_equals_model(data, model)
for data, model in zip(
response.json()["data"], # type: ignore[attr-defined]
models,
):
self.assert_data_equals_model(
data,
model,
model_serializer_class,
)

return response

Expand All @@ -311,18 +329,24 @@ def partial_update(
model: AnyModel,
data: Data,
status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK,
model_serializer_class: t.Optional[
t.Type[ModelSerializer[AnyModel]]
] = None,
**kwargs,
):
# pylint: disable=line-too-long
"""Partially update a model.
Args:
model: The model to partially update.
data: The values for each field.
status_code_assertion: The expected status code.
model_serializer_class: The serializer used to serialize the model's data.
Returns:
The HTTP response.
"""
# pylint: enable=line-too-long

response: Response = self.patch(
self.reverse("detail", model),
Expand All @@ -336,6 +360,7 @@ def partial_update(
self.assert_data_equals_model(
response.json(), # type: ignore[attr-defined]
model,
model_serializer_class,
contains_subset=True,
)

Expand Down Expand Up @@ -367,7 +392,9 @@ def destroy(

if not anonymized and self.status_code_is_ok(response.status_code):
# pylint: disable-next=no-member
with self._test_case.assertRaises(model.DoesNotExist):
with self._test_case.assertRaises(
model.DoesNotExist # type: ignore[attr-defined]
):
model.refresh_from_db()

return response
Expand Down Expand Up @@ -397,20 +424,23 @@ def login(self, **credentials):

return user

def login_teacher(self, is_admin: bool, **credentials):
def login_teacher(self, is_admin: t.Optional[bool] = None, **credentials):
# pylint: disable=line-too-long
"""Log in a user and assert they are a teacher.
Args:
is_admin: Whether or not the teacher is an admin.
is_admin: Whether or not the teacher is an admin. Set none if a teacher can be either or.
Returns:
The teacher-user.
"""
# pylint: enable=line-too-long

user = self.login(**credentials)
assert user.teacher
assert user.teacher.school
assert is_admin == user.teacher.is_admin
if is_admin is not None:
assert is_admin == user.teacher.is_admin
return user

def login_student(self, **credentials):
Expand Down Expand Up @@ -438,56 +468,20 @@ def login_indy(self, **credentials):
return user


class ModelViewSetTestCase(
APITestCase,
t.Generic[AnyModelViewSet, AnyModelSerializer, AnyModel],
):
class ModelViewSetTestCase(APITestCase, t.Generic[AnyModel]):
"""Base for all model view set test cases."""

basename: str
client: ModelViewSetClient[ # type: ignore[assignment]
AnyModelViewSet,
AnyModelSerializer,
AnyModel,
]
model_view_set_class: t.Type[ModelViewSet[AnyModel]]
model_serializer_class: t.Optional[t.Type[ModelSerializer[AnyModel]]] = None
client: ModelViewSetClient[AnyModel]
client_class = ModelViewSetClient # type: ignore[assignment]

def _pre_setup(self):
super()._pre_setup()
super()._pre_setup() # type: ignore[misc]
# pylint: disable-next=protected-access
self.client._test_case = self

@classmethod
def _get_generic_args(
cls,
) -> t.Tuple[
t.Type[AnyModelViewSet],
t.Type[AnyModelSerializer],
t.Type[AnyModel],
]:
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0]) # type: ignore[attr-defined,return-value]

@classmethod
def get_model_view_set_class(cls):
"""Get the model view set's class.
Returns:
The model view set's class.
"""

return cls._get_generic_args()[0]

@classmethod
def get_model_serializer_class(cls):
"""Get the model serializer's class.
Returns:
The model serializer's class.
"""

return cls._get_generic_args()[1]

@classmethod
def get_model_class(cls):
"""Get the model view set's class.
Expand All @@ -496,7 +490,10 @@ def get_model_class(cls):
The model view set's class.
"""

return cls._get_generic_args()[2]
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]

def get_other_user(
self,
Expand Down
21 changes: 9 additions & 12 deletions codeforlife/user/migrations/0001_initial.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated by Django 3.2.20 on 2023-09-29 17:53
# Generated by Django 3.2.20 on 2024-01-24 18:42

import django.contrib.auth.models
import django.core.validators
Expand Down Expand Up @@ -43,6 +43,14 @@ class Migration(migrations.Migration):
'abstract': False,
},
),
migrations.CreateModel(
name='OtpBypassToken',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('token', models.CharField(max_length=8, validators=[django.core.validators.MinLengthValidator(8)])),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='otp_bypass_tokens', to='user.user')),
],
),
migrations.CreateModel(
name='AuthFactor',
fields=[
Expand All @@ -65,15 +73,4 @@ class Migration(migrations.Migration):
'unique_together': {('session', 'auth_factor')},
},
),
migrations.CreateModel(
name='OtpBypassToken',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('token', models.CharField(max_length=8, validators=[django.core.validators.MinLengthValidator(8)])),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='otp_bypass_tokens', to='user.user')),
],
options={
'unique_together': {('user', 'token')},
},
),
]
29 changes: 23 additions & 6 deletions codeforlife/user/models/otp_bypass_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
from django.core.exceptions import ValidationError
from django.core.validators import MinLengthValidator
from django.db import models
from django.utils.crypto import get_random_string

from . import user


class OtpBypassToken(models.Model):
length = 8
allowed_chars = "abcdefghijklmnopqrstuvwxyz"
max_count = 10
max_count_validation_error = ValidationError(
f"Exceeded max count of {max_count}"
Expand Down Expand Up @@ -51,13 +54,10 @@ def key(otp_bypass_token: OtpBypassToken):
)

token = models.CharField(
max_length=8,
validators=[MinLengthValidator(8)],
max_length=length,
validators=[MinLengthValidator(length)],
)

class Meta:
unique_together = ["user", "token"]

def save(self, *args, **kwargs):
if self.id is None:
if (
Expand All @@ -69,7 +69,24 @@ def save(self, *args, **kwargs):
return super().save(*args, **kwargs)

def check_token(self, token: str):
if check_password(token, self.token):
if check_password(token.lower(), self.token):
self.delete()
return True
return False

@classmethod
def generate_tokens(cls, count: int = max_count):
"""Generates a number of tokens.
Args:
count: The number of tokens to generate. Default to max.
Returns:
Raw tokens that are random and unique.
"""

tokens: t.Set[str] = set()
while len(tokens) < count:
tokens.add(get_random_string(cls.length, cls.allowed_chars))

return tokens
Loading

0 comments on commit 8288825

Please sign in to comment.