diff --git a/codeforlife/app.py b/codeforlife/app.py index 963dd5a7..b46cf1fd 100644 --- a/codeforlife/app.py +++ b/codeforlife/app.py @@ -6,6 +6,7 @@ import multiprocessing import typing as t +from django.core.management import call_command from gunicorn.app.base import BaseApplication # type: ignore[import-untyped] @@ -19,6 +20,8 @@ class StandaloneApplication(BaseApplication): """ def __init__(self, app: t.Callable): + call_command("migrate", interactive=False) + self.options = { "bind": "0.0.0.0:8080", # https://docs.gunicorn.org/en/stable/design.html#how-many-workers diff --git a/codeforlife/forms.py b/codeforlife/forms.py new file mode 100644 index 00000000..b93046a7 --- /dev/null +++ b/codeforlife/forms.py @@ -0,0 +1,81 @@ +""" +© Ocado Group +Created on 07/11/2024 at 15:08:33(+00:00). +""" + +import typing as t + +from django import forms +from django.contrib.auth import authenticate +from django.core.exceptions import ValidationError +from django.core.handlers.wsgi import WSGIRequest + +from .models import AbstractBaseUser +from .types import get_arg + +AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) + + +class BaseLoginForm(forms.Form, t.Generic[AnyAbstractBaseUser]): + """Base login form that all other login forms must inherit.""" + + user: AnyAbstractBaseUser + + @classmethod + def get_user_class(cls) -> t.Type[AnyAbstractBaseUser]: + """Get the user class.""" + return get_arg(cls, 0) + + def __init__(self, request: WSGIRequest, *args, **kwargs): + self.request = request + super().__init__(*args, **kwargs) + + def clean(self): + """Authenticates a user. + + Raises: + ValidationError: If there are form errors. + ValidationError: If the user's credentials were incorrect. + ValidationError: If the user's account is deactivated. + + Returns: + The cleaned form data. + """ + + if self.errors: + raise ValidationError( + "Found form errors. Skipping authentication.", + code="form_errors", + ) + + user = authenticate( + self.request, + **{key: self.cleaned_data[key] for key in self.fields.keys()} + ) + if user is None: + raise ValidationError( + self.get_invalid_login_error_message(), + code="invalid_login", + ) + if not isinstance(user, self.get_user_class()): + raise ValidationError( + "Incorrect user class.", + code="incorrect_user_class", + ) + if not user.is_active: + raise ValidationError( + "User is not active", + code="user_not_active", + ) + + self.user = user + + return self.cleaned_data + + def get_invalid_login_error_message(self) -> str: + """Returns the error message if the user failed to login. + + Raises: + NotImplementedError: If message is not set. + """ + raise NotImplementedError() diff --git a/codeforlife/models/__init__.py b/codeforlife/models/__init__.py index 4335ad17..a8faab8f 100644 --- a/codeforlife/models/__init__.py +++ b/codeforlife/models/__init__.py @@ -3,4 +3,7 @@ Created on 19/01/2024 at 15:20:45(+00:00). """ +from .abstract_base_session import AbstractBaseSession +from .abstract_base_user import AbstractBaseUser from .base import * +from .base_session_store import BaseSessionStore diff --git a/codeforlife/models/abstract_base_session.py b/codeforlife/models/abstract_base_session.py new file mode 100644 index 00000000..d45dc543 --- /dev/null +++ b/codeforlife/models/abstract_base_session.py @@ -0,0 +1,78 @@ +""" +© Ocado Group +Created on 06/11/2024 at 16:44:56(+00:00). +""" + +import typing as t + +from django.contrib.sessions.base_session import ( + AbstractBaseSession as _AbstractBaseSession, +) +from django.db import models +from django.utils import timezone +from django.utils.translation import gettext_lazy as _ + +from .abstract_base_user import AbstractBaseUser + +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from django_stubs_ext.db.models import TypedModelMeta + + from .base_session_store import BaseSessionStore +else: + TypedModelMeta = object + +AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) +# pylint: enable=duplicate-code + + +class AbstractBaseSession(_AbstractBaseSession): + """ + Base session class to be inherited by all session classes. + https://docs.djangoproject.com/en/3.2/topics/http/sessions/#example + """ + + pk: str # type: ignore[assignment] + + user_id: int + + # pylint: disable-next=missing-class-docstring,too-few-public-methods + class Meta(TypedModelMeta): + abstract = True + verbose_name = _("session") + verbose_name_plural = _("sessions") + + @property + def is_expired(self): + """Whether or not this session has expired.""" + return self.expire_date < timezone.now() + + @property + def store(self): + """A store instance for this session.""" + return self.get_session_store_class()(self.session_key) + + @classmethod + def get_session_store_class(cls) -> t.Type["BaseSessionStore"]: + raise NotImplementedError + + @staticmethod + def init_user_field(user_class: t.Type[AnyAbstractBaseUser]): + """Initializes the user field that relates a session to a user. + + Example: + class Session(AbstractBaseSession): + user = AbstractBaseSession.init_user_field(User) + + Args: + user_class: The user model to associate sessions to. + + Returns: + A one-to-one field that relates to the provided user model. + """ + return models.OneToOneField( + user_class, + null=True, + blank=True, + on_delete=models.CASCADE, + ) diff --git a/codeforlife/models/abstract_base_user.py b/codeforlife/models/abstract_base_user.py new file mode 100644 index 00000000..5b2305f4 --- /dev/null +++ b/codeforlife/models/abstract_base_user.py @@ -0,0 +1,58 @@ +""" +© Ocado Group +Created on 06/11/2024 at 16:38:15(+00:00). +""" + +import typing as t +from functools import cached_property + +from django.apps import apps +from django.conf import settings +from django.contrib.auth.models import AbstractBaseUser as _AbstractBaseUser +from django.utils.translation import gettext_lazy as _ + +if t.TYPE_CHECKING: + from django_stubs_ext.db.models import TypedModelMeta + + from .abstract_base_session import AbstractBaseSession +else: + TypedModelMeta = object + + +class AbstractBaseUser(_AbstractBaseUser): + """ + Base user class to be inherited by all user classes. + https://docs.djangoproject.com/en/3.2/topics/auth/customizing/#using-a-custom-user-model-when-starting-a-project + """ + + pk: int + session: "AbstractBaseSession" + + # pylint: disable-next=missing-class-docstring,too-few-public-methods + class Meta(TypedModelMeta): + abstract = True + verbose_name = _("user") + verbose_name_plural = _("users") + + @cached_property + def _session_class(self): + return t.cast( + t.Type["AbstractBaseSession"], + apps.get_model( + app_label=( + t.cast(str, settings.SESSION_ENGINE) + .lower() + .removesuffix(".models.session") + .split(".")[-1] + ), + model_name="session", + ), + ) + + @property + def is_authenticated(self): + """A flag designating if this contributor has authenticated.""" + try: + return self.is_active and not self.session.is_expired + except self._session_class.DoesNotExist: + return False diff --git a/codeforlife/models/base.py b/codeforlife/models/base.py index 462ba671..0f974c6e 100644 --- a/codeforlife/models/base.py +++ b/codeforlife/models/base.py @@ -7,6 +7,7 @@ import typing as t +from django.db.models import Manager from django.db.models import Model as _Model if t.TYPE_CHECKING: @@ -14,16 +15,12 @@ else: TypedModelMeta = object -Id = t.TypeVar("Id") +class Model(_Model): + """Base for all models.""" -class Model(_Model, t.Generic[Id]): - """A base class for all Django models.""" + objects: Manager[t.Self] - id: Id - pk: Id - - # pylint: disable-next=missing-class-docstring,too-few-public-methods class Meta(TypedModelMeta): abstract = True diff --git a/codeforlife/models/base_session_store.py b/codeforlife/models/base_session_store.py new file mode 100644 index 00000000..b6e5e648 --- /dev/null +++ b/codeforlife/models/base_session_store.py @@ -0,0 +1,91 @@ +""" +© Ocado Group +Created on 06/11/2024 at 17:31:32(+00:00). +""" + +import typing as t + +from django.contrib.auth import SESSION_KEY +from django.contrib.sessions.backends.db import SessionStore +from django.utils import timezone + +from ..types import get_arg + +if t.TYPE_CHECKING: + from .abstract_base_session import AbstractBaseSession + from .abstract_base_user import AbstractBaseUser + + AnyAbstractBaseSession = t.TypeVar( + "AnyAbstractBaseSession", bound=AbstractBaseSession + ) + AnyAbstractBaseUser = t.TypeVar( + "AnyAbstractBaseUser", bound=AbstractBaseUser + ) +else: + AnyAbstractBaseSession = t.TypeVar("AnyAbstractBaseSession") + AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser") + + +class BaseSessionStore( + SessionStore, + t.Generic[AnyAbstractBaseSession, AnyAbstractBaseUser], +): + """ + Base session store class to be inherited by all session store classes. + https://docs.djangoproject.com/en/3.2/topics/http/sessions/#example + """ + + @classmethod + def get_model_class(cls) -> t.Type[AnyAbstractBaseSession]: + return get_arg(cls, 0) + + @classmethod + def get_user_class(cls) -> t.Type[AnyAbstractBaseUser]: + """Get the user class.""" + return get_arg(cls, 1) + + def associate_session_to_user( + self, session: AnyAbstractBaseSession, user_id: int + ): + """Associate an anon session to a user. + + Args: + session: The anon session. + user_id: The user to associate. + """ + objects = self.get_user_class().objects # type: ignore[attr-defined] + session.user = objects.get(id=user_id) # type: ignore[attr-defined] + + def create_model_instance(self, data): + try: + user_id = int(data.get(SESSION_KEY)) + except (ValueError, TypeError): + # Create an anon session. + return super().create_model_instance(data) + + model_class = self.get_model_class() + + try: + session = model_class.objects.get( + user_id=user_id, # type: ignore[misc] + ) + except model_class.DoesNotExist: + session = model_class.objects.get(session_key=self.session_key) + self.associate_session_to_user( + t.cast(AnyAbstractBaseSession, session), user_id + ) + + session.session_data = self.encode(data) + + return session + + @classmethod + def clear_expired(cls, user_id=None): + session_query = cls.get_model_class().objects.filter( + expire_date__lt=timezone.now() + ) + + if user_id is not None: + session_query = session_query.filter(user_id=user_id) + + session_query.delete() diff --git a/codeforlife/request/__init__.py b/codeforlife/request/__init__.py new file mode 100644 index 00000000..a8ed8a0c --- /dev/null +++ b/codeforlife/request/__init__.py @@ -0,0 +1,8 @@ +""" +© Ocado Group +Created on 05/11/2024 at 14:40:32(+00:00). +""" + +from .drf import BaseRequest, Request +from .http import BaseHttpRequest, HttpRequest +from .wsgi import BaseWSGIRequest, WSGIRequest diff --git a/codeforlife/request.py b/codeforlife/request/drf.py similarity index 55% rename from codeforlife/request.py rename to codeforlife/request/drf.py index 7024eef8..44b34f8d 100644 --- a/codeforlife/request.py +++ b/codeforlife/request/drf.py @@ -1,92 +1,110 @@ """ © Ocado Group -Created on 19/02/2024 at 15:28:22(+00:00). +Created on 05/11/2024 at 14:41:58(+00:00). -Override default request objects. +Custom Request which hints to our custom types. """ import typing as t -from django.contrib.auth.models import AnonymousUser -from django.core.handlers.wsgi import WSGIRequest as _WSGIRequest -from django.http import HttpRequest as _HttpRequest +from django.contrib.auth.models import AbstractBaseUser, AnonymousUser +from django.contrib.sessions.backends.db import SessionStore as DBStore from rest_framework.request import Request as _Request -from .types import JsonDict, JsonList -from .user.models import ( - AdminSchoolTeacherUser, - AnyUser, - IndependentUser, - NonAdminSchoolTeacherUser, - NonSchoolTeacherUser, - SchoolTeacherUser, - StudentUser, - TeacherUser, - User, -) -from .user.models.session import SessionStore +from ..types import JsonDict, JsonList +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User + from ..user.models.session import SessionStore -# pylint: disable-next=missing-class-docstring -class WSGIRequest(_WSGIRequest): - session: SessionStore - user: t.Union[User, AnonymousUser] + AnyUser = t.TypeVar("AnyUser", bound=User) +else: + AnyUser = t.TypeVar("AnyUser") - -# pylint: disable-next=missing-class-docstring -class HttpRequest(_HttpRequest): - session: SessionStore - user: t.Union[User, AnonymousUser] +AnyDBStore = t.TypeVar("AnyDBStore", bound=DBStore) +AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) +# pylint: enable=duplicate-code # pylint: disable-next=missing-class-docstring,abstract-method -class Request(_Request, t.Generic[AnyUser]): - session: SessionStore +class BaseRequest(_Request, t.Generic[AnyDBStore, AnyAbstractBaseUser]): data: t.Any - - def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs): - super().__init__(*args, **kwargs) - self.user_class = user_class + session: AnyDBStore + user: t.Union[AnyAbstractBaseUser, AnonymousUser] @property def query_params(self) -> t.Dict[str, str]: # type: ignore[override] return super().query_params + @property + def anon_user(self): + """The anonymous user that made the request.""" + return t.cast(AnonymousUser, self.user) + + @property + def auth_user(self): + """The authenticated user that made the request.""" + return t.cast(AnyAbstractBaseUser, self.user) + + @property + def json_dict(self): + """The data as a json dictionary.""" + return t.cast(JsonDict, self.data) + + @property + def json_list(self): + """The data as a json list.""" + return t.cast(JsonList, self.data) + + +# pylint: disable-next=missing-class-docstring,abstract-method +class Request(BaseRequest["SessionStore", AnyUser], t.Generic[AnyUser]): + def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs): + super().__init__(*args, **kwargs) + self.user_class = user_class + @property def user(self): return t.cast(t.Union[AnyUser, AnonymousUser], super().user) @user.setter def user(self, value): - if isinstance(value, User) and not isinstance(value, self.user_class): + # pylint: disable-next=import-outside-toplevel + from ..user.models import User + + if ( + isinstance(value, User) + and issubclass(self.user_class, User) + and not isinstance(value, self.user_class) + ): value = value.as_type(self.user_class) self._user = value self._request.user = value - @property - def anon_user(self): - """The anonymous user that made the request.""" - return t.cast(AnonymousUser, self.user) - - @property - def auth_user(self): - """The authenticated user that made the request.""" - return t.cast(AnyUser, self.user) - @property def teacher_user(self): """The authenticated teacher-user that made the request.""" + # pylint: disable-next=import-outside-toplevel + from ..user.models import TeacherUser + return self.auth_user.as_type(TeacherUser) @property def school_teacher_user(self): """The authenticated school-teacher-user that made the request.""" + # pylint: disable-next=import-outside-toplevel + from ..user.models import SchoolTeacherUser + return self.auth_user.as_type(SchoolTeacherUser) @property def admin_school_teacher_user(self): """The authenticated admin-school-teacher-user that made the request.""" + # pylint: disable-next=import-outside-toplevel + from ..user.models import AdminSchoolTeacherUser + return self.auth_user.as_type(AdminSchoolTeacherUser) @property @@ -94,29 +112,31 @@ def non_admin_school_teacher_user(self): """ The authenticated non-admin-school-teacher-user that made the request. """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import NonAdminSchoolTeacherUser + return self.auth_user.as_type(NonAdminSchoolTeacherUser) @property def non_school_teacher_user(self): """The authenticated non-school-teacher-user that made the request.""" + # pylint: disable-next=import-outside-toplevel + from ..user.models import NonSchoolTeacherUser + return self.auth_user.as_type(NonSchoolTeacherUser) @property def student_user(self): """The authenticated student-user that made the request.""" + # pylint: disable-next=import-outside-toplevel + from ..user.models import StudentUser + return self.auth_user.as_type(StudentUser) @property def indy_user(self): """The authenticated independent-user that made the request.""" - return self.auth_user.as_type(IndependentUser) - - @property - def json_dict(self): - """The data as a json dictionary.""" - return t.cast(JsonDict, self.data) + # pylint: disable-next=import-outside-toplevel + from ..user.models import IndependentUser - @property - def json_list(self): - """The data as a json list.""" - return t.cast(JsonList, self.data) + return self.auth_user.as_type(IndependentUser) diff --git a/codeforlife/request/http.py b/codeforlife/request/http.py new file mode 100644 index 00000000..a64c13e7 --- /dev/null +++ b/codeforlife/request/http.py @@ -0,0 +1,36 @@ +""" +© Ocado Group +Created on 05/11/2024 at 14:41:58(+00:00). + +Custom HttpRequest which hints to our custom types. +""" + +import typing as t + +from django.contrib.auth.models import AbstractBaseUser, AnonymousUser +from django.contrib.sessions.backends.db import SessionStore as DBStore +from django.http import HttpRequest as _HttpRequest + +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User + from ..user.models.session import SessionStore + + AnyUser = t.TypeVar("AnyUser", bound=User) +else: + AnyUser = t.TypeVar("AnyUser") + +AnyDBStore = t.TypeVar("AnyDBStore", bound=DBStore) +AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) +# pylint: enable=duplicate-code + + +# pylint: disable-next=missing-class-docstring +class BaseHttpRequest(_HttpRequest, t.Generic[AnyDBStore, AnyAbstractBaseUser]): + session: AnyDBStore + user: t.Union[AnyAbstractBaseUser, AnonymousUser] + + +# pylint: disable-next=missing-class-docstring +class HttpRequest(BaseHttpRequest["SessionStore", AnyUser], t.Generic[AnyUser]): + pass diff --git a/codeforlife/request/wsgi.py b/codeforlife/request/wsgi.py new file mode 100644 index 00000000..dd4778d3 --- /dev/null +++ b/codeforlife/request/wsgi.py @@ -0,0 +1,36 @@ +""" +© Ocado Group +Created on 05/11/2024 at 14:41:58(+00:00). + +Custom WSGIRequest which hints to our custom types. +""" + +import typing as t + +from django.contrib.auth.models import AbstractBaseUser, AnonymousUser +from django.contrib.sessions.backends.db import SessionStore as DBStore +from django.core.handlers.wsgi import WSGIRequest as _WSGIRequest + +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User + from ..user.models.session import SessionStore + + AnyUser = t.TypeVar("AnyUser", bound=User) +else: + AnyUser = t.TypeVar("AnyUser") + +AnyDBStore = t.TypeVar("AnyDBStore", bound=DBStore) +AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) +# pylint: enable=duplicate-code + + +# pylint: disable-next=missing-class-docstring +class BaseWSGIRequest(_WSGIRequest, t.Generic[AnyDBStore, AnyAbstractBaseUser]): + session: AnyDBStore + user: t.Union[AnyAbstractBaseUser, AnonymousUser] + + +# pylint: disable-next=missing-class-docstring +class WSGIRequest(BaseWSGIRequest["SessionStore", AnyUser], t.Generic[AnyUser]): + pass diff --git a/codeforlife/serializers/__init__.py b/codeforlife/serializers/__init__.py index 30c7968d..e9cec59e 100644 --- a/codeforlife/serializers/__init__.py +++ b/codeforlife/serializers/__init__.py @@ -3,5 +3,6 @@ Created on 20/01/2024 at 11:19:12(+00:00). """ -from .base import * -from .model import * +from .base import BaseSerializer +from .model import BaseModelSerializer, ModelSerializer +from .model_list import BaseModelListSerializer, ModelListSerializer diff --git a/codeforlife/serializers/base.py b/codeforlife/serializers/base.py index a49d0a2c..0ddd962e 100644 --- a/codeforlife/serializers/base.py +++ b/codeforlife/serializers/base.py @@ -10,19 +10,22 @@ from django.views import View from rest_framework.serializers import BaseSerializer as _BaseSerializer -from ..request import Request -from ..user.models import AnyUser as RequestUser +from ..request import BaseRequest + +# pylint: disable=duplicate-code +AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest) +# pylint: enable=duplicate-code # pylint: disable-next=abstract-method -class BaseSerializer(_BaseSerializer, t.Generic[RequestUser]): +class BaseSerializer(_BaseSerializer, t.Generic[AnyBaseRequest]): """Base serializer to be inherited by all other serializers.""" @property def request(self): """The HTTP request that triggered the view.""" - return t.cast(Request[RequestUser], self.context["request"]) + return t.cast(AnyBaseRequest, self.context["request"]) @property def view(self): diff --git a/codeforlife/serializers/model.py b/codeforlife/serializers/model.py index 1300483c..0125ffb8 100644 --- a/codeforlife/serializers/model.py +++ b/codeforlife/serializers/model.py @@ -8,38 +8,39 @@ import typing as t from django.db.models import Model -from rest_framework.serializers import ListSerializer as _ListSerializer from rest_framework.serializers import ModelSerializer as _ModelSerializer -from rest_framework.serializers import ValidationError as _ValidationError -from ..types import DataDict, OrderedDataDict -from ..user.models import AnyUser as RequestUser +from ..request import BaseRequest, Request +from ..types import DataDict from .base import BaseSerializer -AnyModel = t.TypeVar("AnyModel", bound=Model) +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User + from ..views import BaseModelViewSet, ModelViewSet + RequestUser = t.TypeVar("RequestUser", bound=User) + AnyBaseModelViewSet = t.TypeVar( + "AnyBaseModelViewSet", bound=BaseModelViewSet + ) +else: + RequestUser = t.TypeVar("RequestUser") + AnyBaseModelViewSet = t.TypeVar("AnyBaseModelViewSet") -BulkCreateDataList = t.List[DataDict] -BulkUpdateDataDict = t.Dict[t.Any, DataDict] -Data = t.Union[BulkCreateDataList, BulkUpdateDataDict] +AnyModel = t.TypeVar("AnyModel", bound=Model) +AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest) +# pylint: enable=duplicate-code -class ModelSerializer( - BaseSerializer[RequestUser], +class BaseModelSerializer( + BaseSerializer[AnyBaseRequest], _ModelSerializer[AnyModel], - t.Generic[RequestUser, AnyModel], + t.Generic[AnyBaseRequest, AnyBaseModelViewSet, AnyModel], ): """Base model serializer for all model serializers.""" instance: t.Optional[AnyModel] - - @property - def view(self): - # NOTE: import outside top-level to avoid circular imports. - # pylint: disable-next=import-outside-toplevel - from ..views import ModelViewSet - - return t.cast(ModelViewSet[RequestUser, AnyModel], super().view) + view: AnyBaseModelViewSet @property def non_none_instance(self): @@ -62,159 +63,12 @@ def to_representation(self, instance: AnyModel) -> DataDict: return super().to_representation(instance) -class ModelListSerializer( - BaseSerializer[RequestUser], - _ListSerializer[t.List[AnyModel]], +class ModelSerializer( + BaseModelSerializer[ + Request[RequestUser], + "ModelViewSet[RequestUser, AnyModel]", + AnyModel, + ], t.Generic[RequestUser, AnyModel], ): - """Base model list serializer for all model list serializers. - - Inherit this class if you wish to custom handle bulk create and/or update. - - class UserListSerializer(ModelListSerializer[User, User]): - def create(self, validated_data): - ... - - def update(self, instance, validated_data): - ... - - class UserSerializer(ModelSerializer[User, User]): - class Meta: - model = User - list_serializer_class = UserListSerializer - """ - - instance: t.Optional[t.List[AnyModel]] - batch_size: t.Optional[int] = None - - @property - def view(self): - # NOTE: import outside top-level to avoid circular imports. - # pylint: disable-next=import-outside-toplevel - from ..views import ModelViewSet - - return t.cast(ModelViewSet[RequestUser, AnyModel], super().view) - - @property - def non_none_instance(self): - """Casts the instance to not None.""" - return t.cast(t.List[AnyModel], self.instance) - - @classmethod - def get_model_class(cls) -> t.Type[AnyModel]: - """Get the model view set's class. - - Returns: - The model view set's class. - """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] - - def __init__(self, *args, **kwargs): - instance = args[0] if args else kwargs.pop("instance", None) - if instance is not None and not isinstance(instance, list): - instance = list(instance) - - super().__init__(instance, *args[1:], **kwargs) - - def create(self, validated_data: t.List[DataDict]) -> t.List[AnyModel]: - """Bulk create many instances of a model. - - https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-create - - Args: - validated_data: The data used to create the models. - - Returns: - The models. - """ - model_class = self.get_model_class() - return model_class.objects.bulk_create( # type: ignore[attr-defined] - objs=[model_class(**data) for data in validated_data], - batch_size=self.batch_size, - ) - - def update( - self, - instance: t.List[AnyModel], - validated_data: t.List[DataDict], - ) -> t.List[AnyModel]: - """Bulk update many instances of a model. - - https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-update - - Args: - instance: The models to update. - validated_data: The field-value pairs to update for each model. - - Returns: - The models. - """ - # Models and data must have equal length and be ordered the same! - for model, data in zip(instance, validated_data): - for field, value in data.items(): - setattr(model, field, value) - - model_class = self.get_model_class() - model_class.objects.bulk_update( # type: ignore[attr-defined] - objs=instance, - fields={field for data in validated_data for field in data.keys()}, - batch_size=self.batch_size, - ) - - return instance - - def validate(self, attrs: t.List[DataDict]): - # If performing a bulk create. - if self.instance is None: - if len(attrs) == 0: - raise _ValidationError( - "Nothing to create.", - code="nothing_to_create", - ) - - # Else, performing a bulk update. - else: - if len(attrs) == 0: - raise _ValidationError( - "Nothing to update.", - code="nothing_to_update", - ) - if len(attrs) != len(self.instance): - raise _ValidationError( - "Some models do not exist.", - code="models_do_not_exist", - ) - - return attrs - - def to_internal_value(self, data: Data): - # If performing a bulk create. - if self.instance is None: - data = t.cast(BulkCreateDataList, data) - - return t.cast( - t.List[OrderedDataDict], - super().to_internal_value(data), - ) - - # Else, performing a bulk update. - data = t.cast(BulkUpdateDataDict, data) - data_items = list(data.items()) - - # Models and data are required to be sorted by the lookup field. - data_items.sort(key=lambda item: item[0]) - self.instance.sort( - key=lambda model: getattr(model, self.view.lookup_field) - ) - - return t.cast( - t.List[OrderedDataDict], - super().to_internal_value([item[1] for item in data_items]), - ) - - # pylint: disable-next=useless-parent-delegation,arguments-renamed - def to_representation(self, instance: t.List[AnyModel]) -> t.List[DataDict]: - return super().to_representation(instance) + """Base model serializer for all model serializers.""" diff --git a/codeforlife/serializers/model_list.py b/codeforlife/serializers/model_list.py new file mode 100644 index 00000000..ecff4273 --- /dev/null +++ b/codeforlife/serializers/model_list.py @@ -0,0 +1,211 @@ +""" +© Ocado Group +Created on 05/11/2024 at 17:53:40(+00:00). + +Base model list serializers. +""" + +import typing as t + +from django.db.models import Model +from rest_framework.serializers import ListSerializer as _ListSerializer +from rest_framework.serializers import ValidationError as _ValidationError + +from ..request import BaseRequest, Request +from ..types import DataDict, OrderedDataDict, get_arg +from .base import BaseSerializer + +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User + from ..views import BaseModelViewSet, ModelViewSet + + RequestUser = t.TypeVar("RequestUser", bound=User) + AnyBaseModelViewSet = t.TypeVar( + "AnyBaseModelViewSet", bound=BaseModelViewSet + ) +else: + RequestUser = t.TypeVar("RequestUser") + AnyBaseModelViewSet = t.TypeVar("AnyBaseModelViewSet") + +AnyModel = t.TypeVar("AnyModel", bound=Model) +AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest) +# pylint: enable=duplicate-code + +BulkCreateDataList = t.List[DataDict] +BulkUpdateDataDict = t.Dict[t.Any, DataDict] +Data = t.Union[BulkCreateDataList, BulkUpdateDataDict] + + +class BaseModelListSerializer( + BaseSerializer[AnyBaseRequest], + _ListSerializer[t.List[AnyModel]], + t.Generic[AnyBaseRequest, AnyBaseModelViewSet, AnyModel], +): + """Base model list serializer for all model list serializers. + + Inherit this class if you wish to custom handle bulk create and/or update. + + class UserListSerializer(ModelListSerializer[User, User]): + def create(self, validated_data): + ... + + def update(self, instance, validated_data): + ... + + class UserSerializer(ModelSerializer[User, User]): + class Meta: + model = User + list_serializer_class = UserListSerializer + """ + + instance: t.Optional[t.List[AnyModel]] + batch_size: t.Optional[int] = None + view: AnyBaseModelViewSet + + @property + def non_none_instance(self): + """Casts the instance to not None.""" + return t.cast(t.List[AnyModel], self.instance) + + @classmethod + def get_model_class(cls) -> t.Type[AnyModel]: + """Get the model view set's class. + + Returns: + The model view set's class. + """ + return get_arg(cls, 0) + + def __init__(self, *args, **kwargs): + instance = args[0] if args else kwargs.pop("instance", None) + if instance is not None and not isinstance(instance, list): + instance = list(instance) + + super().__init__(instance, *args[1:], **kwargs) + + def create(self, validated_data: t.List[DataDict]) -> t.List[AnyModel]: + """Bulk create many instances of a model. + + https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-create + + Args: + validated_data: The data used to create the models. + + Returns: + The models. + """ + model_class = self.get_model_class() + return model_class.objects.bulk_create( # type: ignore[attr-defined] + objs=[model_class(**data) for data in validated_data], + batch_size=self.batch_size, + ) + + def update( + self, + instance: t.List[AnyModel], + validated_data: t.List[DataDict], + ) -> t.List[AnyModel]: + """Bulk update many instances of a model. + + https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-update + + Args: + instance: The models to update. + validated_data: The field-value pairs to update for each model. + + Returns: + The models. + """ + # Models and data must have equal length and be ordered the same! + for model, data in zip(instance, validated_data): + for field, value in data.items(): + setattr(model, field, value) + + model_class = self.get_model_class() + model_class.objects.bulk_update( # type: ignore[attr-defined] + objs=instance, + fields={field for data in validated_data for field in data.keys()}, + batch_size=self.batch_size, + ) + + return instance + + def validate(self, attrs: t.List[DataDict]): + # If performing a bulk create. + if self.instance is None: + if len(attrs) == 0: + raise _ValidationError( + "Nothing to create.", + code="nothing_to_create", + ) + + # Else, performing a bulk update. + else: + if len(attrs) == 0: + raise _ValidationError( + "Nothing to update.", + code="nothing_to_update", + ) + if len(attrs) != len(self.instance): + raise _ValidationError( + "Some models do not exist.", + code="models_do_not_exist", + ) + + return attrs + + def to_internal_value(self, data: Data): + # If performing a bulk create. + if self.instance is None: + data = t.cast(BulkCreateDataList, data) + + return t.cast( + t.List[OrderedDataDict], + super().to_internal_value(data), + ) + + # Else, performing a bulk update. + data = t.cast(BulkUpdateDataDict, data) + data_items = list(data.items()) + + # Models and data are required to be sorted by the lookup field. + data_items.sort(key=lambda item: item[0]) + self.instance.sort( + key=lambda model: getattr(model, self.view.lookup_field) + ) + + return t.cast( + t.List[OrderedDataDict], + super().to_internal_value([item[1] for item in data_items]), + ) + + # pylint: disable-next=useless-parent-delegation,arguments-renamed + def to_representation(self, instance: t.List[AnyModel]) -> t.List[DataDict]: + return super().to_representation(instance) + + +class ModelListSerializer( + BaseModelListSerializer[ + Request[RequestUser], + "ModelViewSet[RequestUser, AnyModel]", + AnyModel, + ], + t.Generic[RequestUser, AnyModel], +): + """Base model list serializer for all model list serializers. + + Inherit this class if you wish to custom handle bulk create and/or update. + + class UserListSerializer(ModelListSerializer[User, User]): + def create(self, validated_data): + ... + + def update(self, instance, validated_data): + ... + + class UserSerializer(ModelSerializer[User, User]): + class Meta: + model = User + list_serializer_class = UserListSerializer + """ diff --git a/codeforlife/tests/__init__.py b/codeforlife/tests/__init__.py index fdaee248..ad077772 100644 --- a/codeforlife/tests/__init__.py +++ b/codeforlife/tests/__init__.py @@ -5,13 +5,19 @@ Custom test cases. """ -from .api import APIClient, APITestCase -from .api_request_factory import APIRequestFactory +from .api import APITestCase, BaseAPITestCase +from .api_client import APIClient, BaseAPIClient +from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory from .cron import CronTestCase from .model import ModelTestCase -from .model_serializer import ( +from .model_list_serializer import ( + BaseModelListSerializerTestCase, ModelListSerializerTestCase, +) +from .model_serializer import ( + BaseModelSerializerTestCase, ModelSerializerTestCase, ) -from .model_view_set import ModelViewSetClient, ModelViewSetTestCase +from .model_view_set import BaseModelViewSetTestCase, ModelViewSetTestCase +from .model_view_set_client import BaseModelViewSetClient, ModelViewSetClient from .test import Client, TestCase diff --git a/codeforlife/tests/api.py b/codeforlife/tests/api.py index 89c771be..ac54e987 100644 --- a/codeforlife/tests/api.py +++ b/codeforlife/tests/api.py @@ -3,495 +3,43 @@ Created on 23/02/2024 at 08:46:27(+00:00). """ -import json import typing as t -from unittest.mock import patch -from django.utils import timezone -from rest_framework import status -from rest_framework.response import Response -from rest_framework.test import APIClient as _APIClient - -from ..types import DataDict, JsonDict -from ..user.models import AdminSchoolTeacherUser -from ..user.models import AnyUser as RequestUser -from ..user.models import ( - AuthFactor, - IndependentUser, - NonAdminSchoolTeacherUser, - NonSchoolTeacherUser, - SchoolTeacherUser, - StudentUser, - TeacherUser, - TypedUser, - User, -) -from .api_request_factory import APIRequestFactory +from ..types import get_arg +from .api_client import APIClient, BaseAPIClient from .test import TestCase -LoginUser = t.TypeVar("LoginUser", bound=User) - - -class APIClient(_APIClient, t.Generic[RequestUser]): - """Base API client to be inherited by all other API clients.""" - - _test_case: "APITestCase[RequestUser]" - - def __init__( - self, - enforce_csrf_checks: bool = False, - raise_request_exception=False, - **defaults, - ): - super().__init__( - enforce_csrf_checks, - raise_request_exception=raise_request_exception, - **defaults, - ) - - self.request_factory = APIRequestFactory( - self.get_request_user_class(), - enforce_csrf_checks, - **defaults, - ) - - @classmethod - def get_request_user_class(cls) -> t.Type[RequestUser]: - """Get the request's user class. - - Returns: - The request's user class. - """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] - - @staticmethod - def status_code_is_ok(status_code: int): - """Check if the status code is greater than or equal to 200 and less - than 300. - - Args: - status_code: The status code to check. - - Returns: - A flag designating if the status code is OK. - """ - return 200 <= status_code < 300 - - # -------------------------------------------------------------------------- - # Assert Response Helpers - # -------------------------------------------------------------------------- - - def _assert_response(self, response: Response, make_assertions: t.Callable): - if self.status_code_is_ok(response.status_code): - make_assertions() - - def _assert_response_json( - self, - response: Response, - make_assertions: t.Callable[[JsonDict], None], - ): - self._assert_response( - response, - make_assertions=lambda: make_assertions( - response.json(), # type: ignore[attr-defined] - ), - ) - - def _assert_response_json_bulk( - self, - response: Response, - make_assertions: t.Callable[[t.List[JsonDict]], None], - data: t.List[DataDict], - ): - def _make_assertions(): - response_json = response.json() # type: ignore[attr-defined] - assert isinstance(response_json, list) - assert len(response_json) == len(data) - make_assertions(response_json) +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User - self._assert_response(response, _make_assertions) + RequestUser = t.TypeVar("RequestUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") - # -------------------------------------------------------------------------- - # Login Helpers - # -------------------------------------------------------------------------- +AnyBaseAPIClient = t.TypeVar("AnyBaseAPIClient", bound=BaseAPIClient) +# pylint: enable=duplicate-code - def _login_user_type(self, user_type: t.Type[LoginUser], **credentials): - # Logout current user (if any) before logging in next user. - self.logout() - assert super().login( - **credentials - ), f"Failed to login with credentials: {credentials}." - - user = user_type.objects.get(session=self.session.session_key) - - if user.session.auth_factors.filter( - auth_factor__type=AuthFactor.Type.OTP - ).exists(): - now = timezone.now() - otp = user.totp.at(now) - with patch.object(timezone, "now", return_value=now): - assert super().login( - request=self.request_factory.post( - user=t.cast(RequestUser, user) - ), - otp=otp, - ), f'Failed to login with OTP "{otp}" at {now}.' - - assert user.is_authenticated, "Failed to authenticate user." - - return user - - def login(self, **credentials): - """Log in a user. - - Returns: - The user. - """ - return self._login_user_type(User, **credentials) - - def login_teacher(self, email: str, password: str = "password"): - """Log in a user and assert they are a teacher. - - Args: - email: The user's email address. - password: The user's password. - - Returns: - The teacher-user. - """ - return self._login_user_type( - TeacherUser, email=email, password=password - ) - - def login_school_teacher(self, email: str, password: str = "password"): - """Log in a user and assert they are a school-teacher. - - Args: - email: The user's email address. - password: The user's password. - - Returns: - The school-teacher-user. - """ - return self._login_user_type( - SchoolTeacherUser, email=email, password=password - ) - - def login_admin_school_teacher( - self, email: str, password: str = "password" - ): - """Log in a user and assert they are an admin-school-teacher. - - Args: - email: The user's email address. - password: The user's password. - - Returns: - The admin-school-teacher-user. - """ - return self._login_user_type( - AdminSchoolTeacherUser, email=email, password=password - ) - def login_non_admin_school_teacher( - self, email: str, password: str = "password" - ): - """Log in a user and assert they are a non-admin-school-teacher. - - Args: - email: The user's email address. - password: The user's password. - - Returns: - The non-admin-school-teacher-user. - """ - return self._login_user_type( - NonAdminSchoolTeacherUser, email=email, password=password - ) - - def login_non_school_teacher(self, email: str, password: str = "password"): - """Log in a user and assert they are a non-school-teacher. - - Args: - email: The user's email address. - password: The user's password. - - Returns: - The non-school-teacher-user. - """ - return self._login_user_type( - NonSchoolTeacherUser, email=email, password=password - ) - - def login_student( - self, class_id: str, first_name: str, password: str = "password" - ): - """Log in a user and assert they are a student. - - Args: - class_id: The ID of the class the student belongs to. - first_name: The user's first name. - password: The user's password. - - Returns: - The student-user. - """ - return self._login_user_type( - StudentUser, - first_name=first_name, - password=password, - class_id=class_id, - ) - - def login_indy(self, email: str, password: str = "password"): - """Log in a user and assert they are an independent. - - Args: - email: The user's email address. - password: The user's password. - - Returns: - The independent-user. - """ - return self._login_user_type( - IndependentUser, email=email, password=password - ) - - def login_as(self, user: TypedUser, password: str = "password"): - """Log in as a user. The user instance needs to be a user proxy in order - to know which credentials are required. - - Args: - user: The user to log in as. - password: The user's password. - """ - auth_user = None - - if isinstance(user, AdminSchoolTeacherUser): - auth_user = self.login_admin_school_teacher(user.email, password) - elif isinstance(user, NonAdminSchoolTeacherUser): - auth_user = self.login_non_admin_school_teacher( - user.email, password - ) - elif isinstance(user, SchoolTeacherUser): - auth_user = self.login_school_teacher(user.email, password) - elif isinstance(user, NonSchoolTeacherUser): - auth_user = self.login_non_school_teacher(user.email, password) - elif isinstance(user, TeacherUser): - auth_user = self.login_teacher(user.email, password) - elif isinstance(user, StudentUser): - auth_user = self.login_student( - user.student.class_field.access_code, - user.first_name, - password, - ) - elif isinstance(user, IndependentUser): - auth_user = self.login_indy(user.email, password) - - assert user == auth_user - - # -------------------------------------------------------------------------- - # Request Helpers - # -------------------------------------------------------------------------- - - StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]] - - # pylint: disable=too-many-arguments,redefined-builtin - - def generic( - self, - method, - path, - data="", - content_type="application/json", - secure=False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - response = t.cast( - Response, - super().generic( - method, - path, - data, - content_type, - secure, - **extra, - ), - ) - - # Use a custom kwarg to handle the common case of checking the - # response's status code. - if status_code_assertion is None: - status_code_assertion = self.status_code_is_ok - elif isinstance(status_code_assertion, int): - expected_status_code = status_code_assertion - status_code_assertion = ( - # pylint: disable-next=unnecessary-lambda-assignment - lambda status_code: status_code - == expected_status_code - ) - - # pylint: disable-next=no-member - status_code = response.status_code - assert status_code_assertion( - status_code - ), f"Unexpected status code: {status_code}." + ( - "\nValidation errors: " - + json.dumps( - # pylint: disable-next=no-member - response.json(), # type: ignore[attr-defined] - indent=2, - default=str, - ) - if status_code == status.HTTP_400_BAD_REQUEST - else "" - ) - - return response - - def get( # type: ignore[override] - self, - path: str, - data: t.Any = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - return super().get( - path=path, - data=data, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - def post( # type: ignore[override] - self, - path: str, - data: t.Any = None, - format: t.Optional[str] = None, - content_type: t.Optional[str] = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - if format is None and content_type is None: - format = "json" - - return super().post( - path=path, - data=data, - format=format, - content_type=content_type, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - def put( # type: ignore[override] - self, - path: str, - data: t.Any = None, - format: t.Optional[str] = None, - content_type: t.Optional[str] = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - if format is None and content_type is None: - format = "json" - - return super().put( - path=path, - data=data, - format=format, - content_type=content_type, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - def patch( # type: ignore[override] - self, - path: str, - data: t.Any = None, - format: t.Optional[str] = None, - content_type: t.Optional[str] = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - if format is None and content_type is None: - format = "json" - - return super().patch( - path=path, - data=data, - format=format, - content_type=content_type, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - def delete( # type: ignore[override] - self, - path: str, - data: t.Any = None, - format: t.Optional[str] = None, - content_type: t.Optional[str] = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - if format is None and content_type is None: - format = "json" - - return super().delete( - path=path, - data=data, - format=format, - content_type=content_type, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - def options( # type: ignore[override] - self, - path: str, - data: t.Any = None, - format: t.Optional[str] = None, - content_type: t.Optional[str] = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - if format is None and content_type is None: - format = "json" +class BaseAPITestCase(TestCase, t.Generic[AnyBaseAPIClient]): + """Base API test case to be inherited by all other API test cases.""" - return super().options( - path=path, - data=data, - format=format, - content_type=content_type, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) + client: AnyBaseAPIClient + client_class: t.Type[AnyBaseAPIClient] - # pylint: enable=too-many-arguments,redefined-builtin + def _pre_setup(self): + # pylint: disable-next=protected-access + self.client_class._test_case = self + super()._pre_setup() # type: ignore[misc] -class APITestCase(TestCase, t.Generic[RequestUser]): +class APITestCase( + BaseAPITestCase[APIClient[RequestUser]], + t.Generic[RequestUser], +): """Base API test case to be inherited by all other API test cases.""" - client: APIClient[RequestUser] - client_class: t.Type[APIClient[RequestUser]] = APIClient + client_class = APIClient @classmethod def get_request_user_class(cls) -> t.Type[RequestUser]: @@ -500,10 +48,7 @@ def get_request_user_class(cls) -> t.Type[RequestUser]: Returns: The request's user class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) def _get_client_class(self): # pylint: disable-next=too-few-public-methods @@ -512,10 +57,10 @@ class _Client( self.get_request_user_class() ] ): - _test_case = self + pass return _Client def _pre_setup(self): self.client_class = self._get_client_class() - super()._pre_setup() # type: ignore[misc] + super()._pre_setup() diff --git a/codeforlife/tests/api_client.py b/codeforlife/tests/api_client.py new file mode 100644 index 00000000..3dfa4450 --- /dev/null +++ b/codeforlife/tests/api_client.py @@ -0,0 +1,549 @@ +""" +© Ocado Group +Created on 06/11/2024 at 13:35:13(+00:00). +""" + +import json +import typing as t +from unittest.mock import patch + +from django.utils import timezone +from rest_framework import status +from rest_framework.response import Response +from rest_framework.test import APIClient as _APIClient + +from ..types import DataDict, JsonDict, get_arg +from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory + +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import TypedUser, User + from .api import APITestCase, BaseAPITestCase + + RequestUser = t.TypeVar("RequestUser", bound=User) + LoginUser = t.TypeVar("LoginUser", bound=User) + AnyBaseAPITestCase = t.TypeVar("AnyBaseAPITestCase", bound=BaseAPITestCase) +else: + RequestUser = t.TypeVar("RequestUser") + LoginUser = t.TypeVar("LoginUser") + AnyBaseAPITestCase = t.TypeVar("AnyBaseAPITestCase") + +AnyBaseAPIRequestFactory = t.TypeVar( + "AnyBaseAPIRequestFactory", bound=BaseAPIRequestFactory +) +# pylint: enable=duplicate-code + + +class BaseAPIClient( + _APIClient, + t.Generic[AnyBaseAPITestCase, AnyBaseAPIRequestFactory], +): + """Base API client to be inherited by all other API clients.""" + + _test_case: AnyBaseAPITestCase + + request_factory: AnyBaseAPIRequestFactory + request_factory_class: t.Type[AnyBaseAPIRequestFactory] + + def _initialize_request_factory( + self, enforce_csrf_checks: bool, **defaults + ): + return self.request_factory_class(enforce_csrf_checks, **defaults) + + def __init__( + self, + enforce_csrf_checks: bool = False, + raise_request_exception=False, + **defaults, + ): + super().__init__( + enforce_csrf_checks, + raise_request_exception=raise_request_exception, + **defaults, + ) + + self.request_factory = self._initialize_request_factory( + enforce_csrf_checks, **defaults + ) + + @staticmethod + def status_code_is_ok(status_code: int): + """Check if the status code is greater than or equal to 200 and less + than 300. + + Args: + status_code: The status code to check. + + Returns: + A flag designating if the status code is OK. + """ + return 200 <= status_code < 300 + + # -------------------------------------------------------------------------- + # Assert Response Helpers + # -------------------------------------------------------------------------- + + def _assert_response(self, response: Response, make_assertions: t.Callable): + if self.status_code_is_ok(response.status_code): + make_assertions() + + def _assert_response_json( + self, + response: Response, + make_assertions: t.Callable[[JsonDict], None], + ): + self._assert_response( + response, + make_assertions=lambda: make_assertions( + response.json(), # type: ignore[attr-defined] + ), + ) + + def _assert_response_json_bulk( + self, + response: Response, + make_assertions: t.Callable[[t.List[JsonDict]], None], + data: t.List[DataDict], + ): + def _make_assertions(): + response_json = response.json() # type: ignore[attr-defined] + assert isinstance(response_json, list) + assert len(response_json) == len(data) + make_assertions(response_json) + + self._assert_response(response, _make_assertions) + + # -------------------------------------------------------------------------- + # Request Helpers + # -------------------------------------------------------------------------- + + StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]] + + # pylint: disable=too-many-arguments,redefined-builtin + + def generic( + self, + method, + path, + data="", + content_type="application/json", + secure=False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + response = t.cast( + Response, + super().generic( + method, + path, + data, + content_type, + secure, + **extra, + ), + ) + + # Use a custom kwarg to handle the common case of checking the + # response's status code. + if status_code_assertion is None: + status_code_assertion = self.status_code_is_ok + elif isinstance(status_code_assertion, int): + expected_status_code = status_code_assertion + status_code_assertion = ( + # pylint: disable-next=unnecessary-lambda-assignment + lambda status_code: status_code + == expected_status_code + ) + + # pylint: disable-next=no-member + status_code = response.status_code + assert status_code_assertion( + status_code + ), f"Unexpected status code: {status_code}." + ( + "\nValidation errors: " + + json.dumps( + # pylint: disable-next=no-member + response.json(), # type: ignore[attr-defined] + indent=2, + default=str, + ) + if status_code == status.HTTP_400_BAD_REQUEST + else "" + ) + + return response + + def get( # type: ignore[override] + self, + path: str, + data: t.Any = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + return super().get( + path=path, + data=data, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + def post( # type: ignore[override] + self, + path: str, + data: t.Any = None, + format: t.Optional[str] = None, + content_type: t.Optional[str] = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + if format is None and content_type is None: + format = "json" + + return super().post( + path=path, + data=data, + format=format, + content_type=content_type, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + def put( # type: ignore[override] + self, + path: str, + data: t.Any = None, + format: t.Optional[str] = None, + content_type: t.Optional[str] = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + if format is None and content_type is None: + format = "json" + + return super().put( + path=path, + data=data, + format=format, + content_type=content_type, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + def patch( # type: ignore[override] + self, + path: str, + data: t.Any = None, + format: t.Optional[str] = None, + content_type: t.Optional[str] = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + if format is None and content_type is None: + format = "json" + + return super().patch( + path=path, + data=data, + format=format, + content_type=content_type, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + def delete( # type: ignore[override] + self, + path: str, + data: t.Any = None, + format: t.Optional[str] = None, + content_type: t.Optional[str] = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + if format is None and content_type is None: + format = "json" + + return super().delete( + path=path, + data=data, + format=format, + content_type=content_type, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + def options( # type: ignore[override] + self, + path: str, + data: t.Any = None, + format: t.Optional[str] = None, + content_type: t.Optional[str] = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + if format is None and content_type is None: + format = "json" + + return super().options( + path=path, + data=data, + format=format, + content_type=content_type, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + # pylint: enable=too-many-arguments,redefined-builtin + + +class APIClient( + BaseAPIClient["APITestCase[RequestUser]", APIRequestFactory[RequestUser]], + t.Generic[RequestUser], +): + """Base API client to be inherited by all other API clients.""" + + request_factory_class = APIRequestFactory + + def _initialize_request_factory(self, enforce_csrf_checks, **defaults): + return self.request_factory_class( + self.get_request_user_class(), + enforce_csrf_checks, + **defaults, + ) + + @classmethod + def get_request_user_class(cls) -> t.Type[RequestUser]: + """Get the request's user class. + + Returns: + The request's user class. + """ + return get_arg(cls, 0) + + # -------------------------------------------------------------------------- + # Login Helpers + # -------------------------------------------------------------------------- + + def _login_user_type(self, user_type: t.Type[LoginUser], **credentials): + # pylint: disable-next=import-outside-toplevel + from ..user.models import AuthFactor + + # Logout current user (if any) before logging in next user. + self.logout() + assert super().login( + **credentials + ), f"Failed to login with credentials: {credentials}." + + user = user_type.objects.get(session=self.session.session_key) + + if user.session.auth_factors.filter( + auth_factor__type=AuthFactor.Type.OTP + ).exists(): + now = timezone.now() + otp = user.totp.at(now) + with patch.object(timezone, "now", return_value=now): + assert super().login( + request=self.request_factory.post( + user=t.cast(RequestUser, user) + ), + otp=otp, + ), f'Failed to login with OTP "{otp}" at {now}.' + + assert user.is_authenticated, "Failed to authenticate user." + + return user + + def login(self, **credentials): + """Log in a user. + + Returns: + The user. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import User + + return self._login_user_type(User, **credentials) + + def login_teacher(self, email: str, password: str = "password"): + """Log in a user and assert they are a teacher. + + Args: + email: The user's email address. + password: The user's password. + + Returns: + The teacher-user. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import TeacherUser + + return self._login_user_type( + TeacherUser, email=email, password=password + ) + + def login_school_teacher(self, email: str, password: str = "password"): + """Log in a user and assert they are a school-teacher. + + Args: + email: The user's email address. + password: The user's password. + + Returns: + The school-teacher-user. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import SchoolTeacherUser + + return self._login_user_type( + SchoolTeacherUser, email=email, password=password + ) + + def login_admin_school_teacher( + self, email: str, password: str = "password" + ): + """Log in a user and assert they are an admin-school-teacher. + + Args: + email: The user's email address. + password: The user's password. + + Returns: + The admin-school-teacher-user. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import AdminSchoolTeacherUser + + return self._login_user_type( + AdminSchoolTeacherUser, email=email, password=password + ) + + def login_non_admin_school_teacher( + self, email: str, password: str = "password" + ): + """Log in a user and assert they are a non-admin-school-teacher. + + Args: + email: The user's email address. + password: The user's password. + + Returns: + The non-admin-school-teacher-user. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import NonAdminSchoolTeacherUser + + return self._login_user_type( + NonAdminSchoolTeacherUser, email=email, password=password + ) + + def login_non_school_teacher(self, email: str, password: str = "password"): + """Log in a user and assert they are a non-school-teacher. + + Args: + email: The user's email address. + password: The user's password. + + Returns: + The non-school-teacher-user. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import NonSchoolTeacherUser + + return self._login_user_type( + NonSchoolTeacherUser, email=email, password=password + ) + + def login_student( + self, class_id: str, first_name: str, password: str = "password" + ): + """Log in a user and assert they are a student. + + Args: + class_id: The ID of the class the student belongs to. + first_name: The user's first name. + password: The user's password. + + Returns: + The student-user. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import StudentUser + + return self._login_user_type( + StudentUser, + first_name=first_name, + password=password, + class_id=class_id, + ) + + def login_indy(self, email: str, password: str = "password"): + """Log in a user and assert they are an independent. + + Args: + email: The user's email address. + password: The user's password. + + Returns: + The independent-user. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import IndependentUser + + return self._login_user_type( + IndependentUser, email=email, password=password + ) + + def login_as(self, user: "TypedUser", password: str = "password"): + """Log in as a user. The user instance needs to be a user proxy in order + to know which credentials are required. + + Args: + user: The user to log in as. + password: The user's password. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import ( + AdminSchoolTeacherUser, + IndependentUser, + NonAdminSchoolTeacherUser, + NonSchoolTeacherUser, + SchoolTeacherUser, + StudentUser, + TeacherUser, + ) + + auth_user = None + + if isinstance(user, AdminSchoolTeacherUser): + auth_user = self.login_admin_school_teacher(user.email, password) + elif isinstance(user, NonAdminSchoolTeacherUser): + auth_user = self.login_non_admin_school_teacher( + user.email, password + ) + elif isinstance(user, SchoolTeacherUser): + auth_user = self.login_school_teacher(user.email, password) + elif isinstance(user, NonSchoolTeacherUser): + auth_user = self.login_non_school_teacher(user.email, password) + elif isinstance(user, TeacherUser): + auth_user = self.login_teacher(user.email, password) + elif isinstance(user, StudentUser): + auth_user = self.login_student( + user.student.class_field.access_code, + user.first_name, + password, + ) + elif isinstance(user, IndependentUser): + auth_user = self.login_indy(user.email, password) + + assert user == auth_user diff --git a/codeforlife/tests/api_request_factory.py b/codeforlife/tests/api_request_factory.py index fabb0a70..fbce7e56 100644 --- a/codeforlife/tests/api_request_factory.py +++ b/codeforlife/tests/api_request_factory.py @@ -5,6 +5,7 @@ import typing as t +from django.contrib.auth.models import AbstractBaseUser from django.core.handlers.wsgi import WSGIRequest from rest_framework.parsers import ( FileUploadParser, @@ -14,43 +15,44 @@ ) from rest_framework.test import APIRequestFactory as _APIRequestFactory -from ..request import Request -from ..user.models import AnyUser +from ..request import BaseRequest, Request +from ..types import get_arg +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User -class APIRequestFactory(_APIRequestFactory, t.Generic[AnyUser]): - """Custom API request factory that returns DRF's Request object.""" + AnyUser = t.TypeVar("AnyUser", bound=User) +else: + AnyUser = t.TypeVar("AnyUser") - def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs): - super().__init__(*args, **kwargs) - self.user_class = user_class +AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest) +AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) +# pylint: enable=duplicate-code - @classmethod - def get_user_class(cls) -> t.Type[AnyUser]: - """Get the user class. - Returns: - The user class. - """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] - - def request(self, user: t.Optional[AnyUser] = None, **kwargs): - wsgi_request = t.cast(WSGIRequest, super().request(**kwargs)) +class BaseAPIRequestFactory( + _APIRequestFactory, t.Generic[AnyBaseRequest, AnyAbstractBaseUser] +): + """Custom API request factory that returns DRF's Request object.""" - request = Request( - self.user_class, - wsgi_request, - parsers=[ - JSONParser(), - FormParser(), - MultiPartParser(), - FileUploadParser(), - ], + def _init_request(self, wsgi_request: WSGIRequest): + return t.cast( + AnyBaseRequest, + BaseRequest( + wsgi_request, + parsers=[ + JSONParser(), + FormParser(), + MultiPartParser(), + FileUploadParser(), + ], + ), ) + def request(self, user: t.Optional[AnyAbstractBaseUser] = None, **kwargs): + wsgi_request = t.cast(WSGIRequest, super().request(**kwargs)) + request = self._init_request(wsgi_request) if user: # pylint: disable-next=attribute-defined-outside-init request.user = user @@ -65,11 +67,11 @@ def generic( data: t.Optional[str] = None, content_type: t.Optional[str] = None, secure: bool = True, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): return t.cast( - Request[AnyUser], + AnyBaseRequest, super().generic( method, path or "/", @@ -85,11 +87,11 @@ def get( # type: ignore[override] self, path: t.Optional[str] = None, data: t.Any = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): return t.cast( - Request[AnyUser], + AnyBaseRequest, super().get( path or "/", data, @@ -106,14 +108,14 @@ def post( # type: ignore[override] # pylint: disable-next=redefined-builtin format: t.Optional[str] = None, content_type: t.Optional[str] = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): if format is None and content_type is None: format = "json" return t.cast( - Request[AnyUser], + AnyBaseRequest, super().post( path or "/", data, @@ -132,14 +134,14 @@ def put( # type: ignore[override] # pylint: disable-next=redefined-builtin format: t.Optional[str] = None, content_type: t.Optional[str] = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): if format is None and content_type is None: format = "json" return t.cast( - Request[AnyUser], + AnyBaseRequest, super().put( path or "/", data, @@ -158,14 +160,14 @@ def patch( # type: ignore[override] # pylint: disable-next=redefined-builtin format: t.Optional[str] = None, content_type: t.Optional[str] = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): if format is None and content_type is None: format = "json" return t.cast( - Request[AnyUser], + AnyBaseRequest, super().patch( path or "/", data, @@ -184,14 +186,14 @@ def delete( # type: ignore[override] # pylint: disable-next=redefined-builtin format: t.Optional[str] = None, content_type: t.Optional[str] = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): if format is None and content_type is None: format = "json" return t.cast( - Request[AnyUser], + AnyBaseRequest, super().delete( path or "/", data, @@ -210,14 +212,14 @@ def options( # type: ignore[override] # pylint: disable-next=redefined-builtin format: t.Optional[str] = None, content_type: t.Optional[str] = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): if format is None and content_type is None: format = "json" return t.cast( - Request[AnyUser], + AnyBaseRequest, super().options( path or "/", data or {}, @@ -227,3 +229,35 @@ def options( # type: ignore[override] **extra, ), ) + + +class APIRequestFactory( + BaseAPIRequestFactory[Request[AnyUser], AnyUser], + t.Generic[AnyUser], +): + """Custom API request factory that returns DRF's Request object.""" + + def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs): + super().__init__(*args, **kwargs) + self.user_class = user_class + + @classmethod + def get_user_class(cls) -> t.Type[AnyUser]: + """Get the user class. + + Returns: + The user class. + """ + return get_arg(cls, 0) + + def _init_request(self, wsgi_request): + return Request[AnyUser]( + self.user_class, + wsgi_request, + parsers=[ + JSONParser(), + FormParser(), + MultiPartParser(), + FileUploadParser(), + ], + ) diff --git a/codeforlife/tests/model.py b/codeforlife/tests/model.py index e04eaa00..a61edb1c 100644 --- a/codeforlife/tests/model.py +++ b/codeforlife/tests/model.py @@ -10,6 +10,7 @@ from django.db.models import Model from django.db.utils import IntegrityError +from ..types import get_arg from .test import TestCase AnyModel = t.TypeVar("AnyModel", bound=Model) @@ -25,10 +26,7 @@ def get_model_class(cls) -> t.Type[AnyModel]: Returns: The model's class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) def assert_raises_integrity_error(self, *args, **kwargs): """Assert the code block raises an integrity error. diff --git a/codeforlife/tests/model_list_serializer.py b/codeforlife/tests/model_list_serializer.py new file mode 100644 index 00000000..f25da034 --- /dev/null +++ b/codeforlife/tests/model_list_serializer.py @@ -0,0 +1,88 @@ +""" +© Ocado Group +Created on 06/11/2024 at 12:45:33(+00:00). + +Base test case for all model list serializers. +""" + +import typing as t + +from django.db.models import Model + +from ..serializers import ( + BaseModelListSerializer, + BaseModelSerializer, + ModelListSerializer, + ModelSerializer, +) +from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory +from .model_serializer import ( + BaseModelSerializerTestCase, + ModelSerializerTestCase, +) + +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User + + RequestUser = t.TypeVar("RequestUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") + +AnyModel = t.TypeVar("AnyModel", bound=Model) +AnyBaseModelSerializer = t.TypeVar( + "AnyBaseModelSerializer", bound=BaseModelSerializer +) +AnyBaseModelListSerializer = t.TypeVar( + "AnyBaseModelListSerializer", bound=BaseModelListSerializer +) +AnyBaseAPIRequestFactory = t.TypeVar( + "AnyBaseAPIRequestFactory", bound=BaseAPIRequestFactory +) +# pylint: enable=duplicate-code + + +class BaseModelListSerializerTestCase( + BaseModelSerializerTestCase[ + AnyBaseModelSerializer, + AnyBaseAPIRequestFactory, + AnyModel, + ], + t.Generic[ + AnyBaseModelListSerializer, + AnyBaseModelSerializer, + AnyBaseAPIRequestFactory, + AnyModel, + ], +): + """Base for all model serializer test cases.""" + + model_list_serializer_class: t.Type[AnyBaseModelListSerializer] + + REQUIRED_ATTRS = { + "model_list_serializer_class", + "model_serializer_class", + "request_factory_class", + } + + def _init_model_serializer(self, *args, parent=None, **kwargs): + kwargs.setdefault("child", self.model_serializer_class()) + serializer = self.model_list_serializer_class(*args, **kwargs) + if parent: + serializer.parent = parent + + return serializer + + +# pylint: disable-next=too-many-ancestors +class ModelListSerializerTestCase( + BaseModelListSerializerTestCase[ + ModelListSerializer[RequestUser, AnyModel], + ModelSerializer[RequestUser, AnyModel], + APIRequestFactory[RequestUser], + AnyModel, + ], + ModelSerializerTestCase[RequestUser, AnyModel], + t.Generic[RequestUser, AnyModel], +): + """Base for all model serializer test cases.""" diff --git a/codeforlife/tests/model_serializer.py b/codeforlife/tests/model_serializer.py index 8c6cc7bf..1f41dec1 100644 --- a/codeforlife/tests/model_serializer.py +++ b/codeforlife/tests/model_serializer.py @@ -13,54 +13,61 @@ from django.forms.models import model_to_dict from rest_framework.serializers import BaseSerializer, ValidationError -from ..serializers import ModelListSerializer, ModelSerializer -from ..types import DataDict -from ..user.models import AnyUser as RequestUser -from .api_request_factory import APIRequestFactory +from ..serializers import ( + BaseModelListSerializer, + BaseModelSerializer, + ModelSerializer, +) +from ..types import DataDict, get_arg +from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory from .test import TestCase -AnyModel = t.TypeVar("AnyModel", bound=Model) +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User + RequestUser = t.TypeVar("RequestUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") -class ModelSerializerTestCase(TestCase, t.Generic[RequestUser, AnyModel]): +AnyModel = t.TypeVar("AnyModel", bound=Model) +AnyBaseModelSerializer = t.TypeVar( + "AnyBaseModelSerializer", bound=BaseModelSerializer +) +AnyBaseAPIRequestFactory = t.TypeVar( + "AnyBaseAPIRequestFactory", bound=BaseAPIRequestFactory +) +# pylint: enable=duplicate-code + + +class BaseModelSerializerTestCase( + TestCase, + t.Generic[AnyBaseModelSerializer, AnyBaseAPIRequestFactory, AnyModel], +): """Base for all model serializer test cases.""" - model_serializer_class: t.Type[ModelSerializer[RequestUser, AnyModel]] + model_serializer_class: t.Type[AnyBaseModelSerializer] - request_factory: APIRequestFactory[RequestUser] - - @classmethod - def setUpClass(cls): - attr_name = "model_serializer_class" - assert hasattr(cls, attr_name), f'Attribute "{attr_name}" must be set.' + request_factory: AnyBaseAPIRequestFactory + request_factory_class: t.Type[AnyBaseAPIRequestFactory] - cls.request_factory = APIRequestFactory(cls.get_request_user_class()) - - return super().setUpClass() + REQUIRED_ATTRS: t.Set[str] = { + "model_serializer_class", + "request_factory_class", + } @classmethod - def get_request_user_class(cls) -> t.Type[AnyModel]: - """Get the model view set's class. - - Returns: - The model view set's class. - """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + def _initialize_request_factory(cls, **kwargs): + return cls.request_factory_class(**kwargs) @classmethod - def get_model_class(cls) -> t.Type[AnyModel]: - """Get the model view set's class. + def setUpClass(cls): + for attr in cls.REQUIRED_ATTRS: + assert hasattr(cls, attr), f'Attribute "{attr}" must be set.' - Returns: - The model view set's class. - """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 1 - ] + cls.request_factory = cls._initialize_request_factory() + + return super().setUpClass() # -------------------------------------------------------------------------- # Private helpers. @@ -134,7 +141,7 @@ def _assert_many( new_data: t.Optional[t.List[DataDict]], non_model_fields: t.Optional[NonModelFields], get_models: t.Callable[ - [ModelListSerializer[RequestUser, AnyModel], t.List[DataDict]], + [BaseModelListSerializer[t.Any, t.Any, AnyModel], t.List[DataDict]], t.List[AnyModel], ], *args, @@ -147,7 +154,7 @@ def _assert_many( assert len(new_data) == len(validated_data) kwargs.pop("many", None) # many must be True - serializer: ModelListSerializer[RequestUser, AnyModel] = ( + serializer: BaseModelListSerializer[t.Any, t.Any, AnyModel] = ( self._init_model_serializer(*args, **kwargs, many=True) ) @@ -371,31 +378,37 @@ def assert_new_data_is_subset_of_data(new_data: DataDict, data): self._assert_data_is_subset_of_model(data, instance) -class ModelListSerializerTestCase( - ModelSerializerTestCase[RequestUser, AnyModel], +class ModelSerializerTestCase( + BaseModelSerializerTestCase[ + ModelSerializer[RequestUser, AnyModel], + APIRequestFactory[RequestUser], + AnyModel, + ], t.Generic[RequestUser, AnyModel], ): """Base for all model serializer test cases.""" - model_list_serializer_class: t.Type[ - ModelListSerializer[RequestUser, AnyModel] - ] + request_factory_class = APIRequestFactory @classmethod - def setUpClass(cls): - attr_name = "model_list_serializer_class" - assert hasattr(cls, attr_name), f'Attribute "{attr_name}" must be set.' + def get_request_user_class(cls) -> t.Type[AnyModel]: + """Get the model view set's class. - return super().setUpClass() + Returns: + The model view set's class. + """ + return get_arg(cls, 0) - # -------------------------------------------------------------------------- - # Private helpers. - # -------------------------------------------------------------------------- + @classmethod + def get_model_class(cls) -> t.Type[AnyModel]: + """Get the model view set's class. - def _init_model_serializer(self, *args, parent=None, **kwargs): - kwargs.setdefault("child", self.model_serializer_class()) - serializer = self.model_list_serializer_class(*args, **kwargs) - if parent: - serializer.parent = parent + Returns: + The model view set's class. + """ + return get_arg(cls, 1) - return serializer + @classmethod + def _initialize_request_factory(cls, **kwargs): + kwargs["user_class"] = cls.get_request_user_class() + return super()._initialize_request_factory(**kwargs) diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index 855bbf02..2304a926 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -8,615 +8,54 @@ import typing as t from datetime import datetime +from django.contrib.auth import get_user_model from django.core.exceptions import ObjectDoesNotExist from django.db.models import Model -from django.db.models.query import QuerySet from django.urls import reverse -from django.utils.http import urlencode -from rest_framework import status -from rest_framework.response import Response +from ..models import AbstractBaseUser from ..permissions import Permission from ..serializers import BaseSerializer -from ..types import DataDict, JsonDict, KwArgs -from ..user.models import AnyUser as RequestUser -from ..views import ModelViewSet -from .api import APIClient, APITestCase +from ..types import DataDict, JsonDict, KwArgs, get_arg +from ..views import BaseModelViewSet, ModelViewSet +from .api import APITestCase, BaseAPITestCase +from .model_view_set_client import BaseModelViewSetClient, ModelViewSetClient -AnyModel = t.TypeVar("AnyModel", bound=Model) - -# pylint: disable=no-member,too-many-arguments - - -class ModelViewSetClient( - APIClient[RequestUser], t.Generic[RequestUser, AnyModel] -): - """ - An API client that helps make requests to a model view set and assert their - responses. - """ - - _test_case: "ModelViewSetTestCase[RequestUser, AnyModel]" - - @property - def _model_class(self): - """Shortcut to get model class.""" - return self._test_case.get_model_class() - - @property - def _model_view_set_class(self): - """Shortcut to get model view set class.""" - return self._test_case.model_view_set_class - - # -------------------------------------------------------------------------- - # Create (HTTP POST) - # -------------------------------------------------------------------------- - - def _assert_create(self, json_model: JsonDict, action: str): - model = self._model_class.objects.get( - **{self._model_view_set_class.lookup_field: json_model["id"]} - ) - self._test_case.assert_serialized_model_equals_json_model( - model, json_model, action, request_method="post" - ) - - def create( - self, - data: DataDict, - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_201_CREATED - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Create a model. - - Args: - data: The values for each field. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - - response: Response = self.post( - self._test_case.reverse_action("list", kwargs=reverse_kwargs), - data=data, - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - self._assert_response_json( - response, - lambda json_model: self._assert_create( - json_model, action="create" - ), - ) - - return response - - def bulk_create( - self, - data: t.List[DataDict], - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_201_CREATED - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Bulk create many instances of a model. - - Args: - data: The values for each field, for each model. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - - response: Response = self.post( - self._test_case.reverse_action("bulk", kwargs=reverse_kwargs), - data=data, - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - - def _make_assertions(json_models: t.List[JsonDict]): - for json_model in json_models: - self._assert_create(json_model, action="bulk") - - self._assert_response_json_bulk(response, _make_assertions, data) - - return response - - # -------------------------------------------------------------------------- - # Retrieve (HTTP GET) - # -------------------------------------------------------------------------- - - def retrieve( - self, - model: AnyModel, - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_200_OK - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Retrieve a model. - - Args: - model: The model to retrieve. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - - response: Response = self.get( - self._test_case.reverse_action( - "detail", - model, - kwargs=reverse_kwargs, - ), - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - self._assert_response_json( - response, - make_assertions=lambda json_model: ( - self._test_case.assert_serialized_model_equals_json_model( - model, - json_model, - action="retrieve", - request_method="get", - ) - ), - ) - - return response - - def list( - self, - models: t.Collection[AnyModel], - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_200_OK - ), - make_assertions: bool = True, - filters: t.Optional[t.Dict[str, t.Union[str, t.Iterable[str]]]] = None, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Retrieve a list of models. - - Args: - models: The model list to retrieve. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - filters: The filters to apply to the list. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - - query: t.List[t.Tuple[str, str]] = [] - for key, values in (filters or {}).items(): - if isinstance(values, str): - query.append((key, values)) - else: - for value in values: - query.append((key, value)) - - response: Response = self.get( - ( - self._test_case.reverse_action("list", kwargs=reverse_kwargs) - + f"?{urlencode(query)}" - ), - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - - def _make_assertions(response_json: JsonDict): - json_models = t.cast(t.List[JsonDict], response_json["data"]) - assert len(models) == len(json_models) - for model, json_model in zip(models, json_models): - self._test_case.assert_serialized_model_equals_json_model( - model, json_model, action="list", request_method="get" - ) - - self._assert_response_json(response, _make_assertions) - - return response - - # -------------------------------------------------------------------------- - # Partial Update (HTTP PATCH) - # -------------------------------------------------------------------------- - - def _assert_update( - self, - model: AnyModel, - json_model: JsonDict, - action: str, - request_method: str, - partial: bool, - ): - model.refresh_from_db() - self._test_case.assert_serialized_model_equals_json_model( - model, json_model, action, request_method, contains_subset=partial - ) - - def partial_update( - self, - model: AnyModel, - data: DataDict, - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_200_OK - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Partially update a model. - - Args: - model: The model to partially update. - data: The values for each field. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - response: Response = self.patch( - self._test_case.reverse_action( - "detail", - model, - kwargs=reverse_kwargs, - ), - data=data, - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - self._assert_response_json( - response, - make_assertions=lambda json_model: self._assert_update( - model, - json_model, - action="partial_update", - request_method="patch", - partial=True, - ), - ) - - return response - - def bulk_partial_update( - self, - models: t.Union[t.List[AnyModel], QuerySet[AnyModel]], - data: t.List[DataDict], - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_200_OK - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Bulk partially update many instances of a model. - - Args: - models: The models to partially update. - data: The values for each field, for each model. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - if not isinstance(models, list): - models = list(models) - - response: Response = self.patch( - self._test_case.reverse_action("bulk", kwargs=reverse_kwargs), - data=data, - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - - def _make_assertions(json_models: t.List[JsonDict]): - models.sort( - key=lambda model: getattr( - model, self._model_view_set_class.lookup_field - ) - ) - for model, json_model in zip(models, json_models): - self._assert_update( - model, - json_model, - action="bulk", - request_method="patch", - partial=True, - ) - - self._assert_response_json_bulk(response, _make_assertions, data) - - return response - - # -------------------------------------------------------------------------- - # Update (HTTP PUT) - # -------------------------------------------------------------------------- - - def update( - self, - model: AnyModel, - action: str, - data: t.Optional[DataDict] = None, - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_200_OK - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Update a model. - - Args: - model: The model to update. - action: The name of the action. - data: The values for each field. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - response = self.put( - path=self._test_case.reverse_action( - action, model, kwargs=reverse_kwargs - ), - data=data, - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - self._assert_response_json( - response, - make_assertions=lambda json_model: self._assert_update( - model, - json_model, - action, - request_method="put", - partial=False, - ), - ) - - return response - - def bulk_update( - self, - models: t.Union[t.List[AnyModel], QuerySet[AnyModel]], - data: t.List[DataDict], - action: str, - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_200_OK - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Bulk update many instances of a model. - - Args: - models: The models to update. - data: The values for each field, for each model. - action: The name of the action. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - if not isinstance(models, list): - models = list(models) - - assert models - assert len(models) == len(data) - - response = self.put( - self._test_case.reverse_action(action, kwargs=reverse_kwargs), - data={ - getattr(model, self._model_view_set_class.lookup_field): _data - for model, _data in zip(models, data) - }, - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - - def _make_assertions(json_models: t.List[JsonDict]): - models.sort( - key=lambda model: getattr( - model, self._model_view_set_class.lookup_field - ) - ) - for model, json_model in zip(models, json_models): - self._assert_update( - model, - json_model, - action, - request_method="put", - partial=False, - ) - - self._assert_response_json_bulk(response, _make_assertions, data) - - return response - - # -------------------------------------------------------------------------- - # Destroy (HTTP DELETE) - # -------------------------------------------------------------------------- - - def _assert_destroy(self, lookup_values: t.List): - assert not self._model_class.objects.filter( - **{f"{self._model_view_set_class.lookup_field}__in": lookup_values} - ).exists() - - def destroy( - self, - model: AnyModel, - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_204_NO_CONTENT - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Destroy a model. - - Args: - model: The model to destroy. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - - response: Response = self.delete( - self._test_case.reverse_action( - "detail", - model, - kwargs=reverse_kwargs, - ), - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - self._assert_response( - response, - make_assertions=lambda: self._assert_destroy([model.pk]), - ) - - return response - - def bulk_destroy( - self, - data: t.List, - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_204_NO_CONTENT - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Bulk destroy many instances of a model. +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User - Args: - data: The primary keys of the models to lookup and destroy. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long + RequestUser = t.TypeVar("RequestUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") - response: Response = self.delete( - self._test_case.reverse_action("bulk", kwargs=reverse_kwargs), - data=data, - status_code_assertion=status_code_assertion, - **kwargs, - ) +AnyModel = t.TypeVar("AnyModel", bound=Model) +AnyBaseModelViewSetClient = t.TypeVar( + "AnyBaseModelViewSetClient", bound=BaseModelViewSetClient +) +AnyBaseModelViewSet = t.TypeVar("AnyBaseModelViewSet", bound=BaseModelViewSet) +# pylint: enable=duplicate-code - if make_assertions: - self._assert_response( - response, make_assertions=lambda: self._assert_destroy(data) - ) - return response +class BaseModelViewSetTestCase( + BaseAPITestCase[AnyBaseModelViewSetClient], + t.Generic[AnyBaseModelViewSet, AnyBaseModelViewSetClient, AnyModel], +): + """Base for all model view set test cases.""" - # -------------------------------------------------------------------------- - # OTHER - # -------------------------------------------------------------------------- + basename: str + model_view_set_class: t.Type[AnyBaseModelViewSet] - def cron_job(self, action: str): - """Call a CRON job. + REQUIRED_ATTRS: t.Set[str] = {"model_view_set_class", "basename"} - Args: - action: The name of the action. + @classmethod + def get_request_user_class(cls): + """Get the request's user class. Returns: - The HTTP response. + The request's user class. """ - response: Response = self.get( - self._test_case.reverse_action(action), - HTTP_X_APPENGINE_CRON="true", - ) - - return response - - -# pylint: enable=no-member - - -class ModelViewSetTestCase( - APITestCase[RequestUser], t.Generic[RequestUser, AnyModel] -): - """Base for all model view set test cases.""" - - basename: str - model_view_set_class: t.Type[ModelViewSet[RequestUser, AnyModel]] - client: ModelViewSetClient[RequestUser, AnyModel] - client_class: t.Type[ModelViewSetClient[RequestUser, AnyModel]] = ( - ModelViewSetClient - ) + return t.cast(AbstractBaseUser, get_user_model()) @classmethod def get_model_class(cls) -> t.Type[AnyModel]: @@ -626,30 +65,15 @@ def get_model_class(cls) -> t.Type[AnyModel]: The model view set's class. """ # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 1 - ] + return get_arg(cls, 2) @classmethod def setUpClass(cls): - for attr in ["model_view_set_class", "basename"]: + for attr in cls.REQUIRED_ATTRS: assert hasattr(cls, attr), f'Attribute "{attr}" must be set.' return super().setUpClass() - def _get_client_class(self): - # TODO: unpack type args in index after moving to python 3.11 - # pylint: disable-next=too-few-public-methods - class _Client( - self.client_class[ # type: ignore[misc] - self.get_request_user_class(), - self.get_model_class(), - ] - ): - _test_case = self - - return _Client - def reverse_action( self, name: str, @@ -687,6 +111,7 @@ def reverse_action( # Assertion Helpers # -------------------------------------------------------------------------- + # pylint: disable-next=too-many-arguments def assert_serialized_model_equals_json_model( self, model: AnyModel, @@ -709,11 +134,8 @@ def assert_serialized_model_equals_json_model( """ # Get the logged-in user. try: - user = t.cast( - RequestUser, - self.get_request_user_class().objects.get( - session=self.client.session.session_key - ), + user = self.get_request_user_class().objects.get( + session=self.client.session.session_key ) except ObjectDoesNotExist: user = None # NOTE: no user has logged in. @@ -721,6 +143,7 @@ def assert_serialized_model_equals_json_model( # Create an instance of the model view set and serializer. model_view_set = self.model_view_set_class( action=action.replace("-", "_"), + # pylint: disable-next=no-member request=self.client.request_factory.generic( request_method, user=user ), @@ -813,6 +236,7 @@ def assert_get_serializer_context( serializer_context: The serializer's context. action: The model view set's action. """ + # pylint: disable-next=no-member kwargs.setdefault("request", self.client.request_factory.get()) kwargs.setdefault("format_kwarg", None) model_view_set = self.model_view_set_class( @@ -823,3 +247,49 @@ def assert_get_serializer_context( actual_serializer_context | serializer_context, actual_serializer_context, ) + + +# pylint: disable-next=too-many-ancestors +class ModelViewSetTestCase( + BaseModelViewSetTestCase[ + ModelViewSet[RequestUser, AnyModel], + ModelViewSetClient[RequestUser, AnyModel], + AnyModel, + ], + APITestCase[RequestUser], + t.Generic[RequestUser, AnyModel], +): + """Base for all model view set test cases.""" + + client_class = ModelViewSetClient + + @classmethod + def get_request_user_class(cls) -> t.Type[RequestUser]: + """Get the request's user class. + + Returns: + The request's user class. + """ + return get_arg(cls, 0) + + @classmethod + def get_model_class(cls) -> t.Type[AnyModel]: + """Get the model view set's class. + + Returns: + The model view set's class. + """ + return get_arg(cls, 1) + + def _get_client_class(self): + # TODO: unpack type args in index after moving to python 3.11 + # pylint: disable-next=too-few-public-methods + class _Client( + self.client_class[ # type: ignore[misc] + self.get_request_user_class(), + self.get_model_class(), + ] + ): + _test_case = self + + return _Client diff --git a/codeforlife/tests/model_view_set_client.py b/codeforlife/tests/model_view_set_client.py new file mode 100644 index 00000000..d430f46b --- /dev/null +++ b/codeforlife/tests/model_view_set_client.py @@ -0,0 +1,638 @@ +""" +© Ocado Group +Created on 06/11/2024 at 14:13:31(+00:00). + +Base test case for all model view clients. +""" + +import typing as t + +from django.db.models import Model +from django.db.models.query import QuerySet +from django.utils.http import urlencode +from rest_framework import status +from rest_framework.response import Response + +from ..types import DataDict, JsonDict, KwArgs +from .api import APIClient, BaseAPIClient +from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory + +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User + from .model_view_set import BaseModelViewSetTestCase, ModelViewSetTestCase + + RequestUser = t.TypeVar("RequestUser", bound=User) + AnyBaseModelViewSetTestCase = t.TypeVar( + "AnyBaseModelViewSetTestCase", bound=BaseModelViewSetTestCase + ) +else: + RequestUser = t.TypeVar("RequestUser") + AnyBaseModelViewSetTestCase = t.TypeVar("AnyBaseModelViewSetTestCase") + +AnyModel = t.TypeVar("AnyModel", bound=Model) +AnyBaseAPIRequestFactory = t.TypeVar( + "AnyBaseAPIRequestFactory", bound=BaseAPIRequestFactory +) +# pylint: enable=duplicate-code + +# pylint: disable=no-member + + +# pylint: disable-next=too-many-ancestors +class BaseModelViewSetClient( + BaseAPIClient[AnyBaseModelViewSetTestCase, AnyBaseAPIRequestFactory], + t.Generic[AnyBaseModelViewSetTestCase, AnyBaseAPIRequestFactory], +): + """ + An API client that helps make requests to a model view set and assert their + responses. + """ + + @property + def _model_class(self): + """Shortcut to get model class.""" + return self._test_case.get_model_class() + + @property + def _model_view_set_class(self): + """Shortcut to get model view set class.""" + return self._test_case.model_view_set_class + + # -------------------------------------------------------------------------- + # Create (HTTP POST) + # -------------------------------------------------------------------------- + + def _assert_create(self, json_model: JsonDict, action: str): + model = self._model_class.objects.get( + **{self._model_view_set_class.lookup_field: json_model["id"]} + ) + self._test_case.assert_serialized_model_equals_json_model( + model, json_model, action, request_method="post" + ) + + def create( + self, + data: DataDict, + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_201_CREATED + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Create a model. + + Args: + data: The values for each field. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + + response: Response = self.post( + self._test_case.reverse_action("list", kwargs=reverse_kwargs), + data=data, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + self._assert_response_json( + response, + lambda json_model: self._assert_create( + json_model, action="create" + ), + ) + + return response + + def bulk_create( + self, + data: t.List[DataDict], + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_201_CREATED + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Bulk create many instances of a model. + + Args: + data: The values for each field, for each model. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + + response: Response = self.post( + self._test_case.reverse_action("bulk", kwargs=reverse_kwargs), + data=data, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + + def _make_assertions(json_models: t.List[JsonDict]): + for json_model in json_models: + self._assert_create(json_model, action="bulk") + + self._assert_response_json_bulk(response, _make_assertions, data) + + return response + + # -------------------------------------------------------------------------- + # Retrieve (HTTP GET) + # -------------------------------------------------------------------------- + + def retrieve( + self, + model: AnyModel, + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_200_OK + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Retrieve a model. + + Args: + model: The model to retrieve. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + + response: Response = self.get( + self._test_case.reverse_action( + "detail", + model, + kwargs=reverse_kwargs, + ), + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + self._assert_response_json( + response, + make_assertions=lambda json_model: ( + self._test_case.assert_serialized_model_equals_json_model( + model, + json_model, + action="retrieve", + request_method="get", + ) + ), + ) + + return response + + # pylint: disable-next=too-many-arguments + def list( + self, + models: t.Collection[AnyModel], + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_200_OK + ), + make_assertions: bool = True, + filters: t.Optional[t.Dict[str, t.Union[str, t.Iterable[str]]]] = None, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Retrieve a list of models. + + Args: + models: The model list to retrieve. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + filters: The filters to apply to the list. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + + query: t.List[t.Tuple[str, str]] = [] + for key, values in (filters or {}).items(): + if isinstance(values, str): + query.append((key, values)) + else: + for value in values: + query.append((key, value)) + + response: Response = self.get( + ( + self._test_case.reverse_action("list", kwargs=reverse_kwargs) + + f"?{urlencode(query)}" + ), + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + + def _make_assertions(response_json: JsonDict): + json_models = t.cast(t.List[JsonDict], response_json["data"]) + assert len(models) == len(json_models) + for model, json_model in zip(models, json_models): + self._test_case.assert_serialized_model_equals_json_model( + model, json_model, action="list", request_method="get" + ) + + self._assert_response_json(response, _make_assertions) + + return response + + # -------------------------------------------------------------------------- + # Partial Update (HTTP PATCH) + # -------------------------------------------------------------------------- + + # pylint: disable-next=too-many-arguments + def _assert_update( + self, + model: AnyModel, + json_model: JsonDict, + action: str, + request_method: str, + partial: bool, + ): + model.refresh_from_db() + self._test_case.assert_serialized_model_equals_json_model( + model, json_model, action, request_method, contains_subset=partial + ) + + # pylint: disable-next=too-many-arguments + def partial_update( + self, + model: AnyModel, + data: DataDict, + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_200_OK + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Partially update a model. + + Args: + model: The model to partially update. + data: The values for each field. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + response: Response = self.patch( + self._test_case.reverse_action( + "detail", + model, + kwargs=reverse_kwargs, + ), + data=data, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + self._assert_response_json( + response, + make_assertions=lambda json_model: self._assert_update( + model, + json_model, + action="partial_update", + request_method="patch", + partial=True, + ), + ) + + return response + + # pylint: disable-next=too-many-arguments + def bulk_partial_update( + self, + models: t.Union[t.List[AnyModel], QuerySet[AnyModel]], + data: t.List[DataDict], + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_200_OK + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Bulk partially update many instances of a model. + + Args: + models: The models to partially update. + data: The values for each field, for each model. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + if not isinstance(models, list): + models = list(models) + + response: Response = self.patch( + self._test_case.reverse_action("bulk", kwargs=reverse_kwargs), + data=data, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + + def _make_assertions(json_models: t.List[JsonDict]): + models.sort( + key=lambda model: getattr( + model, self._model_view_set_class.lookup_field + ) + ) + for model, json_model in zip(models, json_models): + self._assert_update( + model, + json_model, + action="bulk", + request_method="patch", + partial=True, + ) + + self._assert_response_json_bulk(response, _make_assertions, data) + + return response + + # -------------------------------------------------------------------------- + # Update (HTTP PUT) + # -------------------------------------------------------------------------- + + # pylint: disable-next=too-many-arguments + def update( + self, + model: AnyModel, + action: str, + data: t.Optional[DataDict] = None, + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_200_OK + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Update a model. + + Args: + model: The model to update. + action: The name of the action. + data: The values for each field. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + response = self.put( + path=self._test_case.reverse_action( + action, model, kwargs=reverse_kwargs + ), + data=data, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + self._assert_response_json( + response, + make_assertions=lambda json_model: self._assert_update( + model, + json_model, + action, + request_method="put", + partial=False, + ), + ) + + return response + + # pylint: disable-next=too-many-arguments + def bulk_update( + self, + models: t.Union[t.List[AnyModel], QuerySet[AnyModel]], + data: t.List[DataDict], + action: str, + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_200_OK + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Bulk update many instances of a model. + + Args: + models: The models to update. + data: The values for each field, for each model. + action: The name of the action. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + if not isinstance(models, list): + models = list(models) + + assert models + assert len(models) == len(data) + + response = self.put( + self._test_case.reverse_action(action, kwargs=reverse_kwargs), + data={ + getattr(model, self._model_view_set_class.lookup_field): _data + for model, _data in zip(models, data) + }, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + + def _make_assertions(json_models: t.List[JsonDict]): + models.sort( + key=lambda model: getattr( + model, self._model_view_set_class.lookup_field + ) + ) + for model, json_model in zip(models, json_models): + self._assert_update( + model, + json_model, + action, + request_method="put", + partial=False, + ) + + self._assert_response_json_bulk(response, _make_assertions, data) + + return response + + # -------------------------------------------------------------------------- + # Destroy (HTTP DELETE) + # -------------------------------------------------------------------------- + + def _assert_destroy(self, lookup_values: t.List): + assert not self._model_class.objects.filter( + **{f"{self._model_view_set_class.lookup_field}__in": lookup_values} + ).exists() + + def destroy( + self, + model: AnyModel, + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_204_NO_CONTENT + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Destroy a model. + + Args: + model: The model to destroy. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + + response: Response = self.delete( + self._test_case.reverse_action( + "detail", + model, + kwargs=reverse_kwargs, + ), + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + self._assert_response( + response, + make_assertions=lambda: self._assert_destroy([model.pk]), + ) + + return response + + def bulk_destroy( + self, + data: t.List, + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_204_NO_CONTENT + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Bulk destroy many instances of a model. + + Args: + data: The primary keys of the models to lookup and destroy. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + + response: Response = self.delete( + self._test_case.reverse_action("bulk", kwargs=reverse_kwargs), + data=data, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + self._assert_response( + response, make_assertions=lambda: self._assert_destroy(data) + ) + + return response + + # -------------------------------------------------------------------------- + # OTHER + # -------------------------------------------------------------------------- + + def cron_job(self, action: str): + """Call a CRON job. + + Args: + action: The name of the action. + + Returns: + The HTTP response. + """ + response: Response = self.get( + self._test_case.reverse_action(action), + HTTP_X_APPENGINE_CRON="true", + ) + + return response + + +# pylint: enable=no-member + + +# pylint: disable-next=too-many-ancestors +class ModelViewSetClient( # type: ignore[misc] + BaseModelViewSetClient[ + "ModelViewSetTestCase[RequestUser, AnyModel]", + APIRequestFactory[RequestUser], + ], + APIClient[RequestUser], + t.Generic[RequestUser, AnyModel], +): + """ + An API client that helps make requests to a model view set and assert their + responses. + """ diff --git a/codeforlife/types.py b/codeforlife/types.py index 2d151263..68524a19 100644 --- a/codeforlife/types.py +++ b/codeforlife/types.py @@ -16,3 +16,17 @@ DataDict = t.Dict[str, t.Any] OrderedDataDict = t.OrderedDict[str, t.Any] + + +def get_arg(cls: t.Type[t.Any], index: int, orig_base: int = 0): + """Get a type arg from a class. + + Args: + cls: The class to get the type arg from. + index: The index of the type arg to get. + orig_base: The base class to get the type arg from. + + Returns: + The type arg from the class. + """ + return t.get_args(cls.__orig_bases__[orig_base])[index] diff --git a/codeforlife/user/models/session.py b/codeforlife/user/models/session.py index 25ca724f..a2819de2 100644 --- a/codeforlife/user/models/session.py +++ b/codeforlife/user/models/session.py @@ -5,13 +5,9 @@ import typing as t -from django.contrib.auth import SESSION_KEY -from django.contrib.sessions.backends.db import SessionStore as DBStore -from django.contrib.sessions.base_session import AbstractBaseSession -from django.db import models from django.db.models.query import QuerySet -from django.utils import timezone +from ...models import AbstractBaseSession, BaseSessionStore from .user import User if t.TYPE_CHECKING: # pragma: no cover @@ -26,29 +22,14 @@ class Session(AbstractBaseSession): auth_factors: QuerySet["SessionAuthFactor"] - user = models.OneToOneField( - User, - null=True, - blank=True, - on_delete=models.CASCADE, - ) - - @property - def is_expired(self): - """Whether or not this session has expired.""" - return self.expire_date < timezone.now() - - @property - def store(self): - """A store instance for this session.""" - return self.get_session_store_class()(self.session_key) + user = AbstractBaseSession.init_user_field(User) @classmethod def get_session_store_class(cls): return SessionStore -class SessionStore(DBStore): +class SessionStore(BaseSessionStore[Session, User]): """ A custom session store interface to support: 1. creating only one session per user; @@ -57,44 +38,14 @@ class SessionStore(DBStore): https://docs.djangoproject.com/en/3.2/topics/http/sessions/#example """ - @classmethod - def get_model_class(cls): - return Session - - def create_model_instance(self, data): - try: - user_id = int(data.get(SESSION_KEY)) - except (ValueError, TypeError): - # Create an anon session. - return super().create_model_instance(data) - - model_class = self.get_model_class() + def associate_session_to_user(self, session, user_id): + # pylint: disable-next=import-outside-toplevel + from .session_auth_factor import SessionAuthFactor - try: - session = model_class.objects.get(user_id=user_id) - except model_class.DoesNotExist: - # pylint: disable-next=import-outside-toplevel - from .session_auth_factor import SessionAuthFactor - - # Associate session to user. - session = model_class.objects.get(session_key=self.session_key) - session.user = User.objects.get(id=user_id) - SessionAuthFactor.objects.bulk_create( - [ - SessionAuthFactor(session=session, auth_factor=auth_factor) - for auth_factor in session.user.auth_factors.all() - ] - ) - - session.session_data = self.encode(data) - - return session - - @classmethod - def clear_expired(cls, user_id: t.Optional[int] = None): - session_query = cls.get_model_class().objects.filter( - expire_date__lt=timezone.now() + super().associate_session_to_user(session, user_id) + SessionAuthFactor.objects.bulk_create( + [ + SessionAuthFactor(session=session, auth_factor=auth_factor) + for auth_factor in session.user.auth_factors.all() + ] ) - if user_id: - session_query = session_query.filter(user_id=user_id) - session_query.delete() diff --git a/codeforlife/user/models/user.py b/codeforlife/user/models/user.py index b579cdb5..93a557d8 100644 --- a/codeforlife/user/models/user.py +++ b/codeforlife/user/models/user.py @@ -6,6 +6,7 @@ """ import string import typing as t +from datetime import datetime from common.models import TotalActivity, UserProfile @@ -18,6 +19,7 @@ from pyotp import TOTP from ... import mail +from ...models import AbstractBaseUser from .klass import Class from .school import School @@ -33,7 +35,16 @@ TypedModelMeta = object -class User(_User): +# TODO: remove in new schema +class _AbstractBaseUser(AbstractBaseUser): + password: str = None # type: ignore[assignment] + last_login: datetime = None # type: ignore[assignment] + + class Meta(TypedModelMeta): + abstract = True + + +class User(_AbstractBaseUser, _User): """A proxy to Django's user class.""" _password: t.Optional[str] @@ -51,16 +62,12 @@ class Meta(TypedModelMeta): @property def is_authenticated(self): - """ - Check if the user has any pending auth factors. - """ - # pylint: disable-next=import-outside-toplevel - from .session import Session - - try: - return self.is_active and not self.session.auth_factors.exists() - except Session.DoesNotExist: - return False + return ( + not self.session.auth_factors.exists() + and self.userprofile.is_verified + if super().is_authenticated + else False + ) @property def student(self) -> t.Optional["Student"]: @@ -181,6 +188,7 @@ def filter_users(self, queryset: QuerySet[User]): return queryset.exclude(email__isnull=True).exclude(email="") +# pylint: disable-next=too-many-ancestors class ContactableUser(User): """A user that can be contacted.""" @@ -264,6 +272,7 @@ def get_queryset(self): return super().get_queryset().prefetch_related("new_teacher") +# pylint: disable-next=too-many-ancestors class TeacherUser(ContactableUser): """A user that is a teacher.""" @@ -477,6 +486,7 @@ def get_queryset(self): return super().get_queryset().prefetch_related("new_student") +# pylint: disable-next=too-many-ancestors class StudentUser(User): """A user that is a student.""" @@ -591,6 +601,7 @@ def create_user( # type: ignore[override] return user +# pylint: disable-next=too-many-ancestors class IndependentUser(ContactableUser): """A user that is an independent learner.""" diff --git a/codeforlife/user/serializers/user_test.py b/codeforlife/user/serializers/user_test.py index 415a1462..b07240a1 100644 --- a/codeforlife/user/serializers/user_test.py +++ b/codeforlife/user/serializers/user_test.py @@ -8,7 +8,7 @@ from .user import UserSerializer -# pylint: disable-next=missing-class-docstring +# pylint: disable-next=missing-class-docstring,too-many-ancestors class TestUserSerializer(ModelSerializerTestCase[User, User]): model_serializer_class = UserSerializer diff --git a/codeforlife/user/views/klass.py b/codeforlife/user/views/klass.py index 3da9c783..2a4c7b07 100644 --- a/codeforlife/user/views/klass.py +++ b/codeforlife/user/views/klass.py @@ -19,6 +19,7 @@ class ClassViewSet(ModelViewSet[RequestUser, Class]): serializer_class = ClassSerializer filterset_class = ClassFilterSet + # pylint: disable-next=missing-function-docstring def get_permissions(self): # Only school-teachers can list classes. if self.action == "list": diff --git a/codeforlife/user/views/user.py b/codeforlife/user/views/user.py index 90aa9c8c..6b9e08cc 100644 --- a/codeforlife/user/views/user.py +++ b/codeforlife/user/views/user.py @@ -21,6 +21,7 @@ class UserViewSet(ModelViewSet[RequestUser, User]): serializer_class = UserSerializer[User] filterset_class = UserFilterSet + # pylint: disable-next=missing-function-docstring def get_queryset( self, user_class: t.Type[AnyUser] = User, # type: ignore[assignment] @@ -67,6 +68,7 @@ def get_queryset( return queryset.filter(pk=user.pk) + # pylint: disable-next=missing-function-docstring def get_bulk_queryset( # pragma: no cover self, lookup_values: t.Collection, diff --git a/codeforlife/views/__init__.py b/codeforlife/views/__init__.py index 1822c7e9..3abde6bf 100644 --- a/codeforlife/views/__init__.py +++ b/codeforlife/views/__init__.py @@ -3,7 +3,8 @@ Created on 24/01/2024 at 13:07:38(+00:00). """ -from .api import APIView +from .api import APIView, BaseAPIView +from .base_login import BaseLoginView from .common import CsrfCookieView, LogoutView from .decorators import action, cron_job -from .model import ModelViewSet +from .model import BaseModelViewSet, ModelViewSet diff --git a/codeforlife/views/api.py b/codeforlife/views/api.py index af7ad2db..1b95ead1 100644 --- a/codeforlife/views/api.py +++ b/codeforlife/views/api.py @@ -5,15 +5,49 @@ import typing as t +from django.http import HttpRequest from rest_framework.views import APIView as _APIView -from ..request import Request -from ..user.models import AnyUser as RequestUser +from ..request import BaseRequest, Request +from ..types import get_arg + +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User + + RequestUser = t.TypeVar("RequestUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") + +AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest) + +# pylint: enable=duplicate-code + + +# pylint: disable-next=missing-class-docstring +class BaseAPIView(_APIView, t.Generic[AnyBaseRequest]): + request: AnyBaseRequest + request_class: t.Type[AnyBaseRequest] + + def _initialize_request(self, request: HttpRequest, **kwargs): + kwargs["request"] = request + kwargs.setdefault("parsers", self.get_parsers()) + kwargs.setdefault("authenticators", self.get_authenticators()) + kwargs.setdefault("negotiator", self.get_content_negotiator()) + kwargs.setdefault("parser_context", self.get_parser_context(request)) + + return self.request_class(**kwargs) + + def initialize_request(self, request, *args, **kwargs): + # NOTE: Call to super has side effects and is required. + super().initialize_request(request, *args, **kwargs) + + return self._initialize_request(request) # pylint: disable-next=missing-class-docstring -class APIView(_APIView, t.Generic[RequestUser]): - request: Request[RequestUser] +class APIView(BaseAPIView[Request[RequestUser]], t.Generic[RequestUser]): + request_class = Request @classmethod def get_request_user_class(cls) -> t.Type[RequestUser]: @@ -22,20 +56,8 @@ def get_request_user_class(cls) -> t.Type[RequestUser]: Returns: The request's user class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] - - def initialize_request(self, request, *args, **kwargs): - # NOTE: Call to super has side effects and is required. - super().initialize_request(request, *args, **kwargs) + return get_arg(cls, 0) - return Request( - user_class=self.get_request_user_class(), - request=request, - parsers=self.get_parsers(), - authenticators=self.get_authenticators(), - negotiator=self.get_content_negotiator(), - parser_context=self.get_parser_context(request), - ) + def _initialize_request(self, request, **kwargs): + kwargs["user_class"] = self.get_request_user_class() + return super()._initialize_request(request, **kwargs) diff --git a/codeforlife/views/base_login.py b/codeforlife/views/base_login.py new file mode 100644 index 00000000..76eb1a49 --- /dev/null +++ b/codeforlife/views/base_login.py @@ -0,0 +1,106 @@ +""" +© Ocado Group +Created on 07/11/2024 at 14:58:38(+00:00). +""" + +import json +import typing as t +from urllib.parse import quote_plus + +from django.conf import settings +from django.contrib.auth import login +from django.contrib.auth.views import LoginView +from django.http import JsonResponse +from rest_framework import status + +from ..forms import BaseLoginForm +from ..models import AbstractBaseUser +from ..request import BaseHttpRequest +from ..types import JsonDict + +AnyBaseHttpRequest = t.TypeVar("AnyBaseHttpRequest", bound=BaseHttpRequest) +AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) + + +class BaseLoginView( + LoginView, + t.Generic[AnyBaseHttpRequest, AnyAbstractBaseUser], +): + """ + Extends Django's login view to allow a user to log in using one of the + approved forms. + + WARNING: It's critical that to inherit Django's login view as it implements + industry standard security measures that a login view should have. + """ + + request: AnyBaseHttpRequest + + def get_form_kwargs(self): + form_kwargs = super().get_form_kwargs() + form_kwargs["data"] = json.loads(self.request.body) + + return form_kwargs + + def get_session_metadata(self, user: AnyAbstractBaseUser) -> JsonDict: + """Get the session's metadata. + + Args: + user: The user the session is for. + + Raises: + NotImplementedError: If this method is not implemented. + + Returns: + A JSON-serializable dict containing the session's metadata. + """ + raise NotImplementedError + + def form_valid( + self, form: BaseLoginForm[AnyAbstractBaseUser] # type: ignore + ): + user = form.user + + # Clear expired sessions. + self.request.session.clear_expired(user_id=user.pk) + + # Create session (without data). + login(self.request, user) + + # Save session (with data). + self.request.session.save() + + # Get session metadata. + session_metadata = self.get_session_metadata(user) + + # Return session metadata in response and a non-HTTP-only cookie. + response = JsonResponse(session_metadata) + response.set_cookie( + key=settings.SESSION_METADATA_COOKIE_NAME, + value=quote_plus( + json.dumps( + session_metadata, + separators=(",", ":"), + indent=None, + ) + ), + max_age=( + None + if settings.SESSION_EXPIRE_AT_BROWSER_CLOSE + else settings.SESSION_COOKIE_AGE + ), + secure=settings.SESSION_COOKIE_SECURE, + samesite=t.cast( + t.Optional[t.Literal["Lax", "Strict", "None", False]], + settings.SESSION_COOKIE_SAMESITE, + ), + domain=settings.SESSION_COOKIE_DOMAIN, + httponly=False, + ) + + return response + + def form_invalid( + self, form: BaseLoginForm[AnyAbstractBaseUser] # type: ignore + ): + return JsonResponse(form.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/codeforlife/views/model.py b/codeforlife/views/model.py index 14ed52fa..1ac326e5 100644 --- a/codeforlife/views/model.py +++ b/codeforlife/views/model.py @@ -14,16 +14,26 @@ from rest_framework.viewsets import ModelViewSet as DrfModelViewSet from ..permissions import Permission -from ..request import Request -from ..types import KwArgs -from ..user.models import AnyUser as RequestUser -from .api import APIView +from ..request import BaseRequest, Request +from ..types import KwArgs, get_arg +from .api import APIView, BaseAPIView from .decorators import action AnyModel = t.TypeVar("AnyModel", bound=Model) +# pylint: disable=duplicate-code if t.TYPE_CHECKING: # pragma: no cover - from ..serializers import ModelListSerializer, ModelSerializer + from ..serializers import ( + BaseModelSerializer, + ModelListSerializer, + ModelSerializer, + ) + from ..user.models import User + + RequestUser = t.TypeVar("RequestUser", bound=User) + AnyBaseModelSerializer = t.TypeVar( + "AnyBaseModelSerializer", bound=BaseModelSerializer + ) # NOTE: This raises an error during runtime. # pylint: disable-next=too-few-public-methods @@ -31,22 +41,28 @@ class _ModelViewSet(DrfModelViewSet[AnyModel], t.Generic[AnyModel]): pass else: + RequestUser = t.TypeVar("RequestUser") + AnyBaseModelSerializer = t.TypeVar("AnyBaseModelSerializer") + # pylint: disable-next=too-many-ancestors class _ModelViewSet(DrfModelViewSet, t.Generic[AnyModel]): pass +AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest) + +# pylint: enable=duplicate-code + + # pylint: disable-next=too-many-ancestors -class ModelViewSet( - APIView[RequestUser], +class BaseModelViewSet( + BaseAPIView[AnyBaseRequest], _ModelViewSet[AnyModel], - t.Generic[RequestUser, AnyModel], + t.Generic[AnyBaseRequest, AnyBaseModelSerializer, AnyModel], ): """Base model view set for all model view sets.""" - serializer_class: t.Optional[ - t.Type["ModelSerializer[RequestUser, AnyModel]"] - ] + serializer_class: t.Optional[t.Type[AnyBaseModelSerializer]] @classmethod def get_model_class(cls) -> t.Type[AnyModel]: @@ -55,10 +71,7 @@ def get_model_class(cls) -> t.Type[AnyModel]: Returns: The model view set's class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) @cached_property def lookup_field_name(self): @@ -109,47 +122,52 @@ class _ModelListSerializer( return serializer - # -------------------------------------------------------------------------- - # View Set Actions - # -------------------------------------------------------------------------- - # pylint: disable=useless-parent-delegation def destroy( # type: ignore[override] # pragma: no cover - self, request: Request[RequestUser], *args, **kwargs + self, request: AnyBaseRequest, *args, **kwargs ): return super().destroy(request, *args, **kwargs) def create( # type: ignore[override] # pragma: no cover - self, request: Request[RequestUser], *args, **kwargs + self, request: AnyBaseRequest, *args, **kwargs ): return super().create(request, *args, **kwargs) def list( # type: ignore[override] # pragma: no cover - self, request: Request[RequestUser], *args, **kwargs + self, request: AnyBaseRequest, *args, **kwargs ): return super().list(request, *args, **kwargs) def retrieve( # type: ignore[override] # pragma: no cover - self, request: Request[RequestUser], *args, **kwargs + self, request: AnyBaseRequest, *args, **kwargs ): return super().retrieve(request, *args, **kwargs) def update( # type: ignore[override] # pragma: no cover - self, request: Request[RequestUser], *args, **kwargs + self, request: AnyBaseRequest, *args, **kwargs ): return super().update(request, *args, **kwargs) def partial_update( # type: ignore[override] # pragma: no cover - self, request: Request[RequestUser], *args, **kwargs + self, request: AnyBaseRequest, *args, **kwargs ): return super().partial_update(request, *args, **kwargs) # pylint: enable=useless-parent-delegation - # -------------------------------------------------------------------------- - # Bulk Actions - # -------------------------------------------------------------------------- + +# pylint: disable-next=too-many-ancestors +class ModelViewSet( + BaseModelViewSet[ + Request[RequestUser], + "ModelSerializer[RequestUser, AnyModel]", + AnyModel, + ], + APIView[RequestUser], + t.Generic[RequestUser, AnyModel], +): + """Base model view set for all model view sets.""" def get_bulk_queryset(self, lookup_values: t.Collection): """Get the queryset for a bulk action.