diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index 6140585f..5c99fc9b 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -17,21 +17,16 @@ from pyotp import TOTP from rest_framework import status from rest_framework.response import Response -from rest_framework.serializers import ModelSerializer from rest_framework.test import APIClient, APIRequestFactory, APITestCase -from rest_framework.viewsets import ModelViewSet +from ..serializers import ModelSerializer from ..user.models import AuthFactor, User +from ..views import ModelViewSet -AnyModelViewSet = t.TypeVar("AnyModelViewSet", bound=ModelViewSet) -AnyModelSerializer = t.TypeVar("AnyModelSerializer", bound=ModelSerializer) AnyModel = t.TypeVar("AnyModel", bound=Model) -class ModelViewSetClient( - APIClient, - t.Generic[AnyModelViewSet, AnyModelSerializer, AnyModel], -): +class ModelViewSetClient(APIClient, t.Generic[AnyModel]): """ An API client that helps make requests to a model view set and assert their responses. @@ -44,7 +39,7 @@ def __init__(self, enforce_csrf_checks: bool = False, **defaults): **defaults, ) - _test_case: "ModelViewSetTestCase[AnyModelViewSet, AnyModelSerializer, AnyModel]" + _test_case: "ModelViewSetTestCase[AnyModel]" @property def _model_class(self): @@ -53,19 +48,12 @@ def _model_class(self): # pylint: disable-next=no-member return self._test_case.get_model_class() - @property - def _model_serializer_class(self): - """Shortcut to get model serializer class.""" - - # pylint: disable-next=no-member - return self._test_case.get_model_serializer_class() - @property def _model_view_set_class(self): """Shortcut to get model view set class.""" # pylint: disable-next=no-member - return self._test_case.get_model_view_set_class() + return self._test_case.model_view_set_class Data = t.Dict[str, t.Any] StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]] @@ -89,6 +77,9 @@ def assert_data_equals_model( self, data: Data, model: AnyModel, + model_serializer_class: t.Optional[ + t.Type[ModelSerializer[AnyModel]] + ] = None, contains_subset: bool = False, ): # pylint: disable=line-too-long @@ -115,7 +106,14 @@ def parse_data(data): return data.strftime("%Y-%m-%dT%H:%M:%S.%fZ") return data - actual_data = parse_data(self._model_serializer_class(model).data) + if model_serializer_class is None: + model_serializer_class = ( + # pylint: disable-next=no-member + self._test_case.model_serializer_class + or self._model_view_set_class().get_serializer_class() + ) + + actual_data = parse_data(model_serializer_class(model).data) if contains_subset: # pylint: disable-next=no-member @@ -244,17 +242,23 @@ def retrieve( self, model: AnyModel, status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK, + model_serializer_class: t.Optional[ + t.Type[ModelSerializer[AnyModel]] + ] = None, **kwargs, ): + # pylint: disable=line-too-long """Retrieve a model. Args: model: The model to retrieve. status_code_assertion: The expected status code. + model_serializer_class: The serializer used to serialize the model's data. Returns: The HTTP response. """ + # pylint: enable=line-too-long response: Response = self.get( self.reverse("detail", model), @@ -266,6 +270,7 @@ def retrieve( self.assert_data_equals_model( response.json(), # type: ignore[attr-defined] model, + model_serializer_class, ) return response @@ -274,19 +279,25 @@ def list( self, models: t.Iterable[AnyModel], status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK, + model_serializer_class: t.Optional[ + t.Type[ModelSerializer[AnyModel]] + ] = None, filters: ListFilters = None, **kwargs, ): + # pylint: disable=line-too-long """Retrieve a list of models. Args: models: The model list to retrieve. status_code_assertion: The expected status code. + model_serializer_class: The serializer used to serialize the model's data. filters: The filters to apply to the list. Returns: The HTTP response. """ + # pylint: enable=line-too-long assert self._model_class.objects.difference( self._model_class.objects.filter( @@ -301,8 +312,15 @@ def list( ) if self.status_code_is_ok(response.status_code): - for data, model in zip(response.json()["data"], models): # type: ignore[attr-defined] - self.assert_data_equals_model(data, model) + for data, model in zip( + response.json()["data"], # type: ignore[attr-defined] + models, + ): + self.assert_data_equals_model( + data, + model, + model_serializer_class, + ) return response @@ -311,18 +329,24 @@ def partial_update( model: AnyModel, data: Data, status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK, + model_serializer_class: t.Optional[ + t.Type[ModelSerializer[AnyModel]] + ] = None, **kwargs, ): + # pylint: disable=line-too-long """Partially update a model. Args: model: The model to partially update. data: The values for each field. status_code_assertion: The expected status code. + model_serializer_class: The serializer used to serialize the model's data. Returns: The HTTP response. """ + # pylint: enable=line-too-long response: Response = self.patch( self.reverse("detail", model), @@ -336,6 +360,7 @@ def partial_update( self.assert_data_equals_model( response.json(), # type: ignore[attr-defined] model, + model_serializer_class, contains_subset=True, ) @@ -367,7 +392,9 @@ def destroy( if not anonymized and self.status_code_is_ok(response.status_code): # pylint: disable-next=no-member - with self._test_case.assertRaises(model.DoesNotExist): + with self._test_case.assertRaises( + model.DoesNotExist # type: ignore[attr-defined] + ): model.refresh_from_db() return response @@ -397,20 +424,23 @@ def login(self, **credentials): return user - def login_teacher(self, is_admin: bool, **credentials): + def login_teacher(self, is_admin: t.Optional[bool] = None, **credentials): + # pylint: disable=line-too-long """Log in a user and assert they are a teacher. Args: - is_admin: Whether or not the teacher is an admin. + is_admin: Whether or not the teacher is an admin. Set none if a teacher can be either or. Returns: The teacher-user. """ + # pylint: enable=line-too-long user = self.login(**credentials) assert user.teacher assert user.teacher.school - assert is_admin == user.teacher.is_admin + if is_admin is not None: + assert is_admin == user.teacher.is_admin return user def login_student(self, **credentials): @@ -438,56 +468,20 @@ def login_indy(self, **credentials): return user -class ModelViewSetTestCase( - APITestCase, - t.Generic[AnyModelViewSet, AnyModelSerializer, AnyModel], -): +class ModelViewSetTestCase(APITestCase, t.Generic[AnyModel]): """Base for all model view set test cases.""" basename: str - client: ModelViewSetClient[ # type: ignore[assignment] - AnyModelViewSet, - AnyModelSerializer, - AnyModel, - ] + model_view_set_class: t.Type[ModelViewSet[AnyModel]] + model_serializer_class: t.Optional[t.Type[ModelSerializer[AnyModel]]] = None + client: ModelViewSetClient[AnyModel] client_class = ModelViewSetClient # type: ignore[assignment] def _pre_setup(self): - super()._pre_setup() + super()._pre_setup() # type: ignore[misc] # pylint: disable-next=protected-access self.client._test_case = self - @classmethod - def _get_generic_args( - cls, - ) -> t.Tuple[ - t.Type[AnyModelViewSet], - t.Type[AnyModelSerializer], - t.Type[AnyModel], - ]: - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0]) # type: ignore[attr-defined,return-value] - - @classmethod - def get_model_view_set_class(cls): - """Get the model view set's class. - - Returns: - The model view set's class. - """ - - return cls._get_generic_args()[0] - - @classmethod - def get_model_serializer_class(cls): - """Get the model serializer's class. - - Returns: - The model serializer's class. - """ - - return cls._get_generic_args()[1] - @classmethod def get_model_class(cls): """Get the model view set's class. @@ -496,7 +490,10 @@ def get_model_class(cls): The model view set's class. """ - return cls._get_generic_args()[2] + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 0 + ] def get_other_user( self, diff --git a/codeforlife/user/migrations/0001_initial.py b/codeforlife/user/migrations/0001_initial.py index 034c7cbe..d27af2a9 100644 --- a/codeforlife/user/migrations/0001_initial.py +++ b/codeforlife/user/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 3.2.20 on 2023-09-29 17:53 +# Generated by Django 3.2.20 on 2024-01-24 18:42 import django.contrib.auth.models import django.core.validators @@ -43,6 +43,14 @@ class Migration(migrations.Migration): 'abstract': False, }, ), + migrations.CreateModel( + name='OtpBypassToken', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('token', models.CharField(max_length=8, validators=[django.core.validators.MinLengthValidator(8)])), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='otp_bypass_tokens', to='user.user')), + ], + ), migrations.CreateModel( name='AuthFactor', fields=[ @@ -65,15 +73,4 @@ class Migration(migrations.Migration): 'unique_together': {('session', 'auth_factor')}, }, ), - migrations.CreateModel( - name='OtpBypassToken', - fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('token', models.CharField(max_length=8, validators=[django.core.validators.MinLengthValidator(8)])), - ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='otp_bypass_tokens', to='user.user')), - ], - options={ - 'unique_together': {('user', 'token')}, - }, - ), ] diff --git a/codeforlife/user/models/otp_bypass_token.py b/codeforlife/user/models/otp_bypass_token.py index fbd22ad8..2692a367 100644 --- a/codeforlife/user/models/otp_bypass_token.py +++ b/codeforlife/user/models/otp_bypass_token.py @@ -5,11 +5,14 @@ from django.core.exceptions import ValidationError from django.core.validators import MinLengthValidator from django.db import models +from django.utils.crypto import get_random_string from . import user class OtpBypassToken(models.Model): + length = 8 + allowed_chars = "abcdefghijklmnopqrstuvwxyz" max_count = 10 max_count_validation_error = ValidationError( f"Exceeded max count of {max_count}" @@ -51,13 +54,10 @@ def key(otp_bypass_token: OtpBypassToken): ) token = models.CharField( - max_length=8, - validators=[MinLengthValidator(8)], + max_length=length, + validators=[MinLengthValidator(length)], ) - class Meta: - unique_together = ["user", "token"] - def save(self, *args, **kwargs): if self.id is None: if ( @@ -69,7 +69,24 @@ def save(self, *args, **kwargs): return super().save(*args, **kwargs) def check_token(self, token: str): - if check_password(token, self.token): + if check_password(token.lower(), self.token): self.delete() return True return False + + @classmethod + def generate_tokens(cls, count: int = max_count): + """Generates a number of tokens. + + Args: + count: The number of tokens to generate. Default to max. + + Returns: + Raw tokens that are random and unique. + """ + + tokens: t.Set[str] = set() + while len(tokens) < count: + tokens.add(get_random_string(cls.length, cls.allowed_chars)) + + return tokens diff --git a/codeforlife/user/permissions/__init__.py b/codeforlife/user/permissions/__init__.py index cb7689b5..66153640 100644 --- a/codeforlife/user/permissions/__init__.py +++ b/codeforlife/user/permissions/__init__.py @@ -1,2 +1,10 @@ -from .is_school_member import IsSchoolMember -from .is_school_teacher import IsSchoolTeacher +""" +© Ocado Group +Created on 14/12/2023 at 14:05:06(+00:00). +""" + +from .in_class import InClass +from .in_school import InSchool +from .is_independent import IsIndependent +from .is_student import IsStudent +from .is_teacher import IsTeacher diff --git a/codeforlife/user/permissions/in_class.py b/codeforlife/user/permissions/in_class.py new file mode 100644 index 00000000..c030e6d0 --- /dev/null +++ b/codeforlife/user/permissions/in_class.py @@ -0,0 +1,46 @@ +""" +© Ocado Group +Created on 12/12/2023 at 15:18:10(+00:00). +""" + +import typing as t + +from rest_framework.permissions import IsAuthenticated +from rest_framework.request import Request +from rest_framework.views import APIView + +from ..models import User + + +class InClass(IsAuthenticated): + """Request's user must be in a class.""" + + def __init__(self, class_id: t.Optional[str] = None): + """Initialize permission. + + Args: + class_id: A class' ID. If None, check if user is in any class. + Else, check if user is in the specific class. + """ + + super().__init__() + self.class_id = class_id + + def has_permission(self, request: Request, view: APIView): + user = request.user + if super().has_permission(request, view) and isinstance(user, User): + if user.teacher is not None: + classes = user.teacher.class_teacher + if self.class_id is not None: + classes = classes.filter(access_code=self.class_id) + return classes.exists() + + if user.student is not None: + if self.class_id is None: + return True + return ( + user.student.class_field is not None + and user.student.class_field.access_code == self.class_id + ) + + return False diff --git a/codeforlife/user/permissions/in_school.py b/codeforlife/user/permissions/in_school.py new file mode 100644 index 00000000..1b866d21 --- /dev/null +++ b/codeforlife/user/permissions/in_school.py @@ -0,0 +1,49 @@ +""" +© Ocado Group +Created on 12/12/2023 at 15:18:27(+00:00). +""" + +import typing as t + +from rest_framework.permissions import IsAuthenticated +from rest_framework.request import Request +from rest_framework.views import APIView + +from ..models import User + + +class InSchool(IsAuthenticated): + """Request's user must be in a school.""" + + def __init__(self, school_id: t.Optional[int] = None): + """Initialize permission. + + Args: + school_id: A school's ID. If None, check if user is in any school. + Else, check if user is in the specific school. + """ + + super().__init__() + self.school_id = school_id + + def has_permission(self, request: Request, view: APIView): + def in_school(school_id: int): + return self.school_id is None or self.school_id == school_id + + user = request.user + return ( + super().has_permission(request, view) + and isinstance(user, User) + and ( + ( + user.teacher is not None + and user.teacher.school_id is not None + and in_school(user.teacher.school_id) + ) + or ( + user.student is not None + and user.student.class_field is not None + and in_school(user.student.class_field.teacher.school_id) + ) + ) + ) diff --git a/codeforlife/user/permissions/is_independent.py b/codeforlife/user/permissions/is_independent.py new file mode 100644 index 00000000..5d0f5a8e --- /dev/null +++ b/codeforlife/user/permissions/is_independent.py @@ -0,0 +1,23 @@ +""" +© Ocado Group +Created on 12/12/2023 at 13:55:47(+00:00). +""" + +from rest_framework.permissions import IsAuthenticated +from rest_framework.request import Request +from rest_framework.views import APIView + +from ..models import User + + +class IsIndependent(IsAuthenticated): + """Request's user must be independent.""" + + def has_permission(self, request: Request, view: APIView): + user = request.user + return ( + super().has_permission(request, view) + and isinstance(user, User) + and user.teacher is None + and user.student is None + ) diff --git a/codeforlife/user/permissions/is_school_member.py b/codeforlife/user/permissions/is_school_member.py deleted file mode 100644 index 43894cdf..00000000 --- a/codeforlife/user/permissions/is_school_member.py +++ /dev/null @@ -1,18 +0,0 @@ -from rest_framework.permissions import BasePermission -from rest_framework.request import Request -from rest_framework.views import View - -from ..models import User - - -class IsSchoolMember(BasePermission): - def has_permission(self, request: Request, view: View): - user = request.user - return isinstance(user, User) and ( - (user.is_teacher and user.teacher.school is not None) - or ( - user.student is not None - # TODO: should be user.student.school is not None - and user.student.class_field is not None - ) - ) diff --git a/codeforlife/user/permissions/is_school_teacher.py b/codeforlife/user/permissions/is_school_teacher.py deleted file mode 100644 index ece94675..00000000 --- a/codeforlife/user/permissions/is_school_teacher.py +++ /dev/null @@ -1,15 +0,0 @@ -from rest_framework.permissions import BasePermission -from rest_framework.request import Request -from rest_framework.views import View - -from ..models import User - - -class IsSchoolTeacher(BasePermission): - def has_permission(self, request: Request, view: View): - user = request.user - return ( - isinstance(user, User) - and user.is_teacher - and user.teacher.school is not None - ) diff --git a/codeforlife/user/permissions/is_student.py b/codeforlife/user/permissions/is_student.py new file mode 100644 index 00000000..a4a43e9e --- /dev/null +++ b/codeforlife/user/permissions/is_student.py @@ -0,0 +1,36 @@ +""" +© Ocado Group +Created on 12/12/2023 at 13:55:40(+00:00). +""" + +import typing as t + +from rest_framework.permissions import IsAuthenticated +from rest_framework.request import Request +from rest_framework.views import APIView + +from ..models import User + + +class IsStudent(IsAuthenticated): + """Request's user must be a student.""" + + def __init__(self, student_id: t.Optional[int] = None): + """Initialize permission. + + Args: + student_id: A student's ID. If None, check if the user is any + student. Else, check if the user is the specific student. + """ + + super().__init__() + self.student_id = student_id + + def has_permission(self, request: Request, view: APIView): + user = request.user + return ( + super().has_permission(request, view) + and isinstance(user, User) + and user.student is not None + and (self.student_id is None or user.student.id == self.student_id) + ) diff --git a/codeforlife/user/permissions/is_teacher.py b/codeforlife/user/permissions/is_teacher.py new file mode 100644 index 00000000..255d9c6d --- /dev/null +++ b/codeforlife/user/permissions/is_teacher.py @@ -0,0 +1,47 @@ +""" +© Ocado Group +Created on 12/12/2023 at 13:55:22(+00:00). +""" + +import typing as t + +from rest_framework.permissions import IsAuthenticated +from rest_framework.request import Request +from rest_framework.views import APIView + +from ..models import User + + +class IsTeacher(IsAuthenticated): + """Request's user must be a teacher.""" + + def __init__( + self, + teacher_id: t.Optional[int] = None, + is_admin: t.Optional[bool] = None, + ): + """Initialize permission. + + Args: + teacher_id: A teacher's ID. If None, check if the user is any + teacher. Else, check if the user is the specific teacher. + is_admin: If the teacher is an admin. If None, don't check if the + teacher is an admin. Else, check if the teacher is (not) an + admin. + """ + + super().__init__() + self.teacher_id = teacher_id + self.is_admin = is_admin + + def has_permission(self, request: Request, view: APIView): + user = request.user + return ( + super().has_permission(request, view) + and isinstance(user, User) + and user.teacher is not None + and (self.teacher_id is None or user.teacher.id == self.teacher_id) + and ( + self.is_admin is None or user.teacher.is_admin == self.is_admin + ) + ) diff --git a/codeforlife/user/tests/auth/backends/test_otp_bypass_token.py b/codeforlife/user/tests/auth/backends/test_otp_bypass_token.py index 4c944c67..973efc5d 100644 --- a/codeforlife/user/tests/auth/backends/test_otp_bypass_token.py +++ b/codeforlife/user/tests/auth/backends/test_otp_bypass_token.py @@ -2,7 +2,6 @@ from django.test import RequestFactory, TestCase from django.utils import timezone -from django.utils.crypto import get_random_string from ....auth.backends import OtpBypassTokenBackend from ....models import ( @@ -38,9 +37,7 @@ def setUp(self): auth_factor=self.auth_factor, ) - self.tokens = [ - get_random_string(8) for _ in range(OtpBypassToken.max_count) - ] + self.tokens = OtpBypassToken.generate_tokens() self.otp_bypass_tokens = OtpBypassToken.objects.bulk_create( [ OtpBypassToken(user=self.user, token=token) @@ -52,7 +49,7 @@ def test_authenticate(self): request = self.request_factory.post("/") request.user = self.user - user = self.backend.authenticate(request, token=self.tokens[0]) + user = self.backend.authenticate(request, token=next(iter(self.tokens))) assert user == self.user assert self.otp_bypass_tokens[0].id is None diff --git a/codeforlife/user/tests/models/test_otp_bypass_token.py b/codeforlife/user/tests/models/test_otp_bypass_token.py index a6683d5c..8aeb8656 100644 --- a/codeforlife/user/tests/models/test_otp_bypass_token.py +++ b/codeforlife/user/tests/models/test_otp_bypass_token.py @@ -1,9 +1,15 @@ +""" +© Ocado Group +Created on 24/01/2024 at 16:17:22(+00:00). +""" + +from unittest.mock import call, patch + from django.contrib.auth.hashers import check_password from django.core.exceptions import ValidationError from django.test import TestCase -from django.utils.crypto import get_random_string -from ...models import OtpBypassToken, User +from ...models import OtpBypassToken, User, otp_bypass_token class TestOtpBypassToken(TestCase): @@ -11,7 +17,7 @@ def setUp(self): self.user = User.objects.get(id=2) def test_bulk_create(self): - token = get_random_string(8) + token = next(iter(OtpBypassToken.generate_tokens(1))) otp_bypass_tokens = OtpBypassToken.objects.bulk_create( [OtpBypassToken(user=self.user, token=token)] ) @@ -20,16 +26,13 @@ def test_bulk_create(self): with self.assertRaises(ValidationError): OtpBypassToken.objects.bulk_create( [ - OtpBypassToken( - user=self.user, - token=get_random_string(8), - ) - for _ in range(OtpBypassToken.max_count) + OtpBypassToken(user=self.user, token=token) + for token in OtpBypassToken.generate_tokens() ] ) def test_create(self): - token = get_random_string(8) + token = next(iter(OtpBypassToken.generate_tokens(1))) otp_bypass_token = OtpBypassToken.objects.create( user=self.user, token=token ) @@ -38,22 +41,21 @@ def test_create(self): OtpBypassToken.objects.bulk_create( [ - OtpBypassToken( - user=self.user, - token=get_random_string(8), + OtpBypassToken(user=self.user, token=token) + for token in OtpBypassToken.generate_tokens( + OtpBypassToken.max_count - 1 ) - for _ in range(OtpBypassToken.max_count - 1) ] ) with self.assertRaises(ValidationError): OtpBypassToken.objects.create( user=self.user, - token=get_random_string(8), + token=next(iter(OtpBypassToken.generate_tokens(1))), ) def test_check_token(self): - token = get_random_string(8) + token = next(iter(OtpBypassToken.generate_tokens(1))) otp_bypass_token = OtpBypassToken.objects.create( user=self.user, token=token ) @@ -65,3 +67,36 @@ def test_check_token(self): user=otp_bypass_token.user, token=otp_bypass_token.token, ) + + def test_generate_tokens(self): + """ + Generates a number of unique tokens. + """ + + count = 3 + get_random_string_side_effect = [ + "aaaaaaaa", + "aaaaaaaa", + "bbbbbbbb", + "cccccccc", + ] + + with patch.object( + otp_bypass_token, + "get_random_string", + side_effect=get_random_string_side_effect, + ) as get_random_string: + tokens = OtpBypassToken.generate_tokens(count) + assert len(tokens) == count + assert tokens == { + "aaaaaaaa", + "bbbbbbbb", + "cccccccc", + } + + get_random_string.assert_has_calls( + [ + call(OtpBypassToken.length, OtpBypassToken.allowed_chars) + for _ in range(len(get_random_string_side_effect)) + ] + ) diff --git a/codeforlife/user/tests/views/__init__.py b/codeforlife/user/tests/views/__init__.py index e69de29b..fcf35baf 100644 --- a/codeforlife/user/tests/views/__init__.py +++ b/codeforlife/user/tests/views/__init__.py @@ -0,0 +1,4 @@ +""" +© Ocado Group +Created on 24/01/2024 at 13:52:24(+00:00). +""" diff --git a/codeforlife/user/tests/views/test_klass.py b/codeforlife/user/tests/views/test_klass.py index 9db22791..eda41b04 100644 --- a/codeforlife/user/tests/views/test_klass.py +++ b/codeforlife/user/tests/views/test_klass.py @@ -5,13 +5,10 @@ from ....tests import ModelViewSetTestCase from ...models import Class -from ...serializers import ClassSerializer from ...views import ClassViewSet -class TestClassViewSet( - ModelViewSetTestCase[ClassViewSet, ClassSerializer, Class] -): +class TestClassViewSet(ModelViewSetTestCase[Class]): """ Base naming convention: test_{action} @@ -21,6 +18,7 @@ class TestClassViewSet( """ basename = "class" + model_view_set_class = ClassViewSet def _login_student(self): return self.client.login_student( diff --git a/codeforlife/user/tests/views/test_school.py b/codeforlife/user/tests/views/test_school.py index ed6c792b..788b5570 100644 --- a/codeforlife/user/tests/views/test_school.py +++ b/codeforlife/user/tests/views/test_school.py @@ -4,17 +4,14 @@ """ from rest_framework import status -from rest_framework.permissions import IsAuthenticated +from rest_framework.permissions import InSc from ....tests import ModelViewSetTestCase from ...models import Class, School, Student, Teacher, User, UserProfile -from ...serializers import SchoolSerializer from ...views import SchoolViewSet -class TestSchoolViewSet( - ModelViewSetTestCase[SchoolViewSet, SchoolSerializer, School] -): +class TestSchoolViewSet(ModelViewSetTestCase[School]): """ Base naming convention: test_{action} @@ -24,6 +21,7 @@ class TestSchoolViewSet( """ basename = "school" + model_view_set_class = SchoolViewSet # TODO: replace this setup with data fixtures. def setUp(self): @@ -196,24 +194,3 @@ def test_list__student(self): user = self._login_student() self.client.list([user.student.class_field.teacher.school]) - - # pylint: disable-next=pointless-string-statement - """ - General tests that apply to all actions. - """ - - def test_all__requires_authentication(self): - """ - User must be authenticated to call any endpoint. - """ - - assert IsAuthenticated in SchoolViewSet.permission_classes - - def test_all__only_http_get(self): - """ - These model are read-only. - """ - - assert [name.lower() for name in SchoolViewSet.http_method_names] == [ - "get" - ] diff --git a/codeforlife/user/tests/views/test_user.py b/codeforlife/user/tests/views/test_user.py index 7ca529ed..793dfc63 100644 --- a/codeforlife/user/tests/views/test_user.py +++ b/codeforlife/user/tests/views/test_user.py @@ -8,12 +8,11 @@ from ....tests import ModelViewSetTestCase from ...models import Class, School, Student, Teacher, User, UserProfile -from ...serializers import UserSerializer from ...views import UserViewSet # pylint: disable-next=too-many-ancestors,too-many-public-methods -class TestUserViewSet(ModelViewSetTestCase[UserViewSet, UserSerializer, User]): +class TestUserViewSet(ModelViewSetTestCase[User]): """ Base naming convention: test_{action} @@ -23,6 +22,7 @@ class TestUserViewSet(ModelViewSetTestCase[UserViewSet, UserSerializer, User]): """ basename = "user" + model_view_set_class = UserViewSet # TODO: replace this setup with data fixtures. def setUp(self): diff --git a/codeforlife/user/views/__init__.py b/codeforlife/user/views/__init__.py index ffcd7b4e..06f853cb 100644 --- a/codeforlife/user/views/__init__.py +++ b/codeforlife/user/views/__init__.py @@ -1,3 +1,8 @@ +""" +© Ocado Group +Created on 24/01/2024 at 13:52:02(+00:00). +""" + from .klass import ClassViewSet from .school import SchoolViewSet from .user import UserViewSet diff --git a/codeforlife/user/views/klass.py b/codeforlife/user/views/klass.py index bc863e88..e4d5ff06 100644 --- a/codeforlife/user/views/klass.py +++ b/codeforlife/user/views/klass.py @@ -1,23 +1,30 @@ -from rest_framework.permissions import IsAuthenticated -from rest_framework.viewsets import ModelViewSet +""" +© Ocado Group +Created on 24/01/2024 at 13:47:53(+00:00). +""" +import typing as t + +from ...views import ModelViewSet from ..models import Class, User -from ..permissions import IsSchoolMember +from ..permissions import InSchool from ..serializers import ClassSerializer -class ClassViewSet(ModelViewSet): +# pylint: disable-next=missing-class-docstring,too-many-ancestors +class ClassViewSet(ModelViewSet[Class]): http_method_names = ["get"] lookup_field = "access_code" serializer_class = ClassSerializer - permission_classes = [IsAuthenticated, IsSchoolMember] + permission_classes = [InSchool] + # pylint: disable-next=missing-function-docstring def get_queryset(self): - user: User = self.request.user + user = t.cast(User, self.request.user) if user.is_student: return Class.objects.filter(students=user.student) - elif user.teacher.is_admin: + if user.teacher.is_admin: # TODO: add school field to class object return Class.objects.filter(teacher__school=user.teacher.school) - else: - return Class.objects.filter(teacher=user.teacher) + + return Class.objects.filter(teacher=user.teacher) diff --git a/codeforlife/user/views/school.py b/codeforlife/user/views/school.py index 0c12d9cd..cc3b903c 100644 --- a/codeforlife/user/views/school.py +++ b/codeforlife/user/views/school.py @@ -1,22 +1,29 @@ -from rest_framework.permissions import IsAuthenticated -from rest_framework.viewsets import ModelViewSet +""" +© Ocado Group +Created on 24/01/2024 at 13:38:15(+00:00). +""" +import typing as t + +from ...views import ModelViewSet from ..models import School, User -from ..permissions import IsSchoolMember +from ..permissions import InSchool from ..serializers import SchoolSerializer -class SchoolViewSet(ModelViewSet): +# pylint: disable-next=missing-class-docstring,too-many-ancestors +class SchoolViewSet(ModelViewSet[School]): http_method_names = ["get"] serializer_class = SchoolSerializer - permission_classes = [IsAuthenticated, IsSchoolMember] + permission_classes = [InSchool] + # pylint: disable-next=missing-function-docstring def get_queryset(self): - user: User = self.request.user + user = t.cast(User, self.request.user) if user.is_student: return School.objects.filter( # TODO: should be user.student.school_id id=user.student.class_field.teacher.school_id ) - else: - return School.objects.filter(id=user.teacher.school_id) + + return School.objects.filter(id=user.teacher.school_id) diff --git a/codeforlife/user/views/user.py b/codeforlife/user/views/user.py index 9efa496c..3942b709 100644 --- a/codeforlife/user/views/user.py +++ b/codeforlife/user/views/user.py @@ -1,17 +1,25 @@ -from rest_framework.viewsets import ModelViewSet +""" +© Ocado Group +Created on 24/01/2024 at 13:12:05(+00:00). +""" +import typing as t + +from ...views import ModelViewSet from ..filters import UserFilterSet from ..models import User from ..serializers import UserSerializer -class UserViewSet(ModelViewSet): +# pylint: disable-next=missing-class-docstring,too-many-ancestors +class UserViewSet(ModelViewSet[User]): http_method_names = ["get"] serializer_class = UserSerializer filterset_class = UserFilterSet + # pylint: disable-next=missing-function-docstring def get_queryset(self): - user: User = self.request.user + user = t.cast(User, self.request.user) if user.is_student: if user.student.class_field is None: return User.objects.filter(id=user.id) diff --git a/codeforlife/views/__init__.py b/codeforlife/views/__init__.py index e69de29b..5b2d49eb 100644 --- a/codeforlife/views/__init__.py +++ b/codeforlife/views/__init__.py @@ -0,0 +1,6 @@ +""" +© Ocado Group +Created on 24/01/2024 at 13:07:38(+00:00). +""" + +from .base import ModelViewSet diff --git a/codeforlife/views/base.py b/codeforlife/views/base.py new file mode 100644 index 00000000..cac85d4a --- /dev/null +++ b/codeforlife/views/base.py @@ -0,0 +1,39 @@ +""" +© Ocado Group +Created on 24/01/2024 at 13:08:23(+00:00). +""" + +import typing as t + +from django.db.models import Model +from rest_framework.viewsets import ModelViewSet as DrfModelViewSet + +from ..serializers import ModelSerializer + +AnyModel = t.TypeVar("AnyModel", bound=Model) + + +# pylint: disable-next=too-few-public-methods +class _ModelViewSet(t.Generic[AnyModel]): + pass + + +if t.TYPE_CHECKING: + # pylint: disable-next=too-few-public-methods + class ModelViewSet( + DrfModelViewSet[AnyModel], + _ModelViewSet[AnyModel], + t.Generic[AnyModel], + ): + """Base model view set for all model view sets.""" + + serializer_class: t.Optional[t.Type[ModelSerializer[AnyModel]]] + +else: + # pylint: disable-next=missing-class-docstring,too-many-ancestors + class ModelViewSet( + DrfModelViewSet, + _ModelViewSet[AnyModel], + t.Generic[AnyModel], + ): + pass diff --git a/pyproject.toml b/pyproject.toml index 75a801c6..0d57c45b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,29 @@ upload_to_release = true [tool.black] line-length = 80 -extend-exclude = "^/codeforlife/user/migrations/" +extend-exclude = ".*/migrations/.*py" [tool.pytest.ini_options] env = ["DJANGO_SETTINGS_MODULE=manage"] + +[tool.mypy] +plugins = ["mypy_django_plugin.main", "mypy_drf_plugin.main"] +check_untyped_defs = true + +[tool.django-stubs] +django_settings_module = "manage" + +[tool.pylint.main] +init-hook = "import os; os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'manage')" + +[tool.pylint.format] +max-line-length = 80 + +[tool.pylint.MASTER] +ignore-paths = [".*/migrations/.*py"] +load-plugins = "pylint_django" + +[tool.isort] +profile = "black" +line_length = 80 +skip_glob = ["**/migrations/*.py"]