From 75456ff48380dce037aef43c6c00608af3fb9a01 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Mon, 4 Nov 2024 16:56:01 +0000 Subject: [PATCH 01/56] avoid importing models --- codeforlife/request.py | 57 ++++++++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/codeforlife/request.py b/codeforlife/request.py index 7024eef8..6c8e342b 100644 --- a/codeforlife/request.py +++ b/codeforlife/request.py @@ -13,38 +13,32 @@ from rest_framework.request import Request as _Request from .types import JsonDict, JsonList -from .user.models import ( - AdminSchoolTeacherUser, - AnyUser, - IndependentUser, - NonAdminSchoolTeacherUser, - NonSchoolTeacherUser, - SchoolTeacherUser, - StudentUser, - TeacherUser, - User, -) -from .user.models.session import SessionStore + +if t.TYPE_CHECKING: + from .user.models import User + from .user.models.session import SessionStore + + AnyUser = t.TypeVar("AnyUser", bound=User) # pylint: disable-next=missing-class-docstring class WSGIRequest(_WSGIRequest): - session: SessionStore - user: t.Union[User, AnonymousUser] + session: "SessionStore" + user: t.Union["User", AnonymousUser] # pylint: disable-next=missing-class-docstring class HttpRequest(_HttpRequest): - session: SessionStore - user: t.Union[User, AnonymousUser] + session: "SessionStore" + user: t.Union["User", AnonymousUser] # pylint: disable-next=missing-class-docstring,abstract-method -class Request(_Request, t.Generic[AnyUser]): - session: SessionStore +class Request(_Request, t.Generic["AnyUser"]): + session: "SessionStore" data: t.Any - def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs): + def __init__(self, user_class: t.Type["AnyUser"], *args, **kwargs): super().__init__(*args, **kwargs) self.user_class = user_class @@ -54,7 +48,7 @@ def query_params(self) -> t.Dict[str, str]: # type: ignore[override] @property def user(self): - return t.cast(t.Union[AnyUser, AnonymousUser], super().user) + return t.cast(t.Union["AnyUser", AnonymousUser], super().user) @user.setter def user(self, value): @@ -72,21 +66,30 @@ def anon_user(self): @property def auth_user(self): """The authenticated user that made the request.""" - return t.cast(AnyUser, self.user) + return t.cast("AnyUser", self.user) @property def teacher_user(self): """The authenticated teacher-user that made the request.""" + # pylint: disable-next=import-outside-toplevel + from .user.models import TeacherUser + return self.auth_user.as_type(TeacherUser) @property def school_teacher_user(self): """The authenticated school-teacher-user that made the request.""" + # pylint: disable-next=import-outside-toplevel + from .user.models import SchoolTeacherUser + return self.auth_user.as_type(SchoolTeacherUser) @property def admin_school_teacher_user(self): """The authenticated admin-school-teacher-user that made the request.""" + # pylint: disable-next=import-outside-toplevel + from .user.models import AdminSchoolTeacherUser + return self.auth_user.as_type(AdminSchoolTeacherUser) @property @@ -94,21 +97,33 @@ def non_admin_school_teacher_user(self): """ The authenticated non-admin-school-teacher-user that made the request. """ + # pylint: disable-next=import-outside-toplevel + from .user.models import NonAdminSchoolTeacherUser + return self.auth_user.as_type(NonAdminSchoolTeacherUser) @property def non_school_teacher_user(self): """The authenticated non-school-teacher-user that made the request.""" + # pylint: disable-next=import-outside-toplevel + from .user.models import NonSchoolTeacherUser + return self.auth_user.as_type(NonSchoolTeacherUser) @property def student_user(self): """The authenticated student-user that made the request.""" + # pylint: disable-next=import-outside-toplevel + from .user.models import StudentUser + return self.auth_user.as_type(StudentUser) @property def indy_user(self): """The authenticated independent-user that made the request.""" + # pylint: disable-next=import-outside-toplevel + from .user.models import IndependentUser + return self.auth_user.as_type(IndependentUser) @property From 4c8e557b9c76a0ee7947253dacfd8de0a72983d6 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Mon, 4 Nov 2024 17:15:02 +0000 Subject: [PATCH 02/56] fix issues --- codeforlife/request.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/codeforlife/request.py b/codeforlife/request.py index 6c8e342b..2e9d18a7 100644 --- a/codeforlife/request.py +++ b/codeforlife/request.py @@ -18,7 +18,7 @@ from .user.models import User from .user.models.session import SessionStore - AnyUser = t.TypeVar("AnyUser", bound=User) +AnyUser = t.TypeVar("AnyUser") # pylint: disable-next=missing-class-docstring @@ -34,11 +34,11 @@ class HttpRequest(_HttpRequest): # pylint: disable-next=missing-class-docstring,abstract-method -class Request(_Request, t.Generic["AnyUser"]): +class Request(_Request, t.Generic[AnyUser]): session: "SessionStore" data: t.Any - def __init__(self, user_class: t.Type["AnyUser"], *args, **kwargs): + def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs): super().__init__(*args, **kwargs) self.user_class = user_class @@ -48,11 +48,18 @@ def query_params(self) -> t.Dict[str, str]: # type: ignore[override] @property def user(self): - return t.cast(t.Union["AnyUser", AnonymousUser], super().user) + return t.cast(t.Union[AnyUser, AnonymousUser], super().user) @user.setter def user(self, value): - if isinstance(value, User) and not isinstance(value, self.user_class): + # pylint: disable-next=import-outside-toplevel + from .user.models import User + + if ( + isinstance(value, User) + and issubclass(self.user_class, User) + and not isinstance(value, self.user_class) + ): value = value.as_type(self.user_class) self._user = value @@ -66,7 +73,7 @@ def anon_user(self): @property def auth_user(self): """The authenticated user that made the request.""" - return t.cast("AnyUser", self.user) + return t.cast(AnyUser, self.user) @property def teacher_user(self): From 77bd91c8c4c99915576c1364da64dbc4c4309e33 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Mon, 4 Nov 2024 17:25:29 +0000 Subject: [PATCH 03/56] fix type imports --- codeforlife/tests/api.py | 87 ++++++++++++++++++++++++++++------------ 1 file changed, 62 insertions(+), 25 deletions(-) diff --git a/codeforlife/tests/api.py b/codeforlife/tests/api.py index 89c771be..c757375b 100644 --- a/codeforlife/tests/api.py +++ b/codeforlife/tests/api.py @@ -13,26 +13,28 @@ from rest_framework.test import APIClient as _APIClient from ..types import DataDict, JsonDict -from ..user.models import AdminSchoolTeacherUser -from ..user.models import AnyUser as RequestUser -from ..user.models import ( - AuthFactor, - IndependentUser, - NonAdminSchoolTeacherUser, - NonSchoolTeacherUser, - SchoolTeacherUser, - StudentUser, - TeacherUser, - TypedUser, - User, -) from .api_request_factory import APIRequestFactory from .test import TestCase -LoginUser = t.TypeVar("LoginUser", bound=User) - - -class APIClient(_APIClient, t.Generic[RequestUser]): +if t.TYPE_CHECKING: + from ..user.models import ( + AdminSchoolTeacherUser, + AuthFactor, + IndependentUser, + NonAdminSchoolTeacherUser, + NonSchoolTeacherUser, + SchoolTeacherUser, + StudentUser, + TeacherUser, + TypedUser, + User, + ) + + RequestUser = t.TypeVar("RequestUser", bound=User) + LoginUser = t.TypeVar("LoginUser", bound=User) + + +class APIClient(_APIClient, t.Generic["RequestUser"]): """Base API client to be inherited by all other API clients.""" _test_case: "APITestCase[RequestUser]" @@ -56,7 +58,7 @@ def __init__( ) @classmethod - def get_request_user_class(cls) -> t.Type[RequestUser]: + def get_request_user_class(cls) -> t.Type["RequestUser"]: """Get the request's user class. Returns: @@ -118,7 +120,10 @@ def _make_assertions(): # Login Helpers # -------------------------------------------------------------------------- - def _login_user_type(self, user_type: t.Type[LoginUser], **credentials): + def _login_user_type(self, user_type: t.Type["LoginUser"], **credentials): + # pylint: disable-next=import-outside-toplevel + from ..user.models import AuthFactor + # Logout current user (if any) before logging in next user. self.logout() assert super().login( @@ -135,7 +140,7 @@ def _login_user_type(self, user_type: t.Type[LoginUser], **credentials): with patch.object(timezone, "now", return_value=now): assert super().login( request=self.request_factory.post( - user=t.cast(RequestUser, user) + user=t.cast("RequestUser", user) ), otp=otp, ), f'Failed to login with OTP "{otp}" at {now}.' @@ -162,6 +167,9 @@ def login_teacher(self, email: str, password: str = "password"): Returns: The teacher-user. """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import TeacherUser + return self._login_user_type( TeacherUser, email=email, password=password ) @@ -176,6 +184,9 @@ def login_school_teacher(self, email: str, password: str = "password"): Returns: The school-teacher-user. """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import SchoolTeacherUser + return self._login_user_type( SchoolTeacherUser, email=email, password=password ) @@ -192,6 +203,9 @@ def login_admin_school_teacher( Returns: The admin-school-teacher-user. """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import AdminSchoolTeacherUser + return self._login_user_type( AdminSchoolTeacherUser, email=email, password=password ) @@ -208,6 +222,9 @@ def login_non_admin_school_teacher( Returns: The non-admin-school-teacher-user. """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import NonAdminSchoolTeacherUser + return self._login_user_type( NonAdminSchoolTeacherUser, email=email, password=password ) @@ -222,6 +239,9 @@ def login_non_school_teacher(self, email: str, password: str = "password"): Returns: The non-school-teacher-user. """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import NonSchoolTeacherUser + return self._login_user_type( NonSchoolTeacherUser, email=email, password=password ) @@ -239,6 +259,9 @@ def login_student( Returns: The student-user. """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import StudentUser + return self._login_user_type( StudentUser, first_name=first_name, @@ -256,11 +279,14 @@ def login_indy(self, email: str, password: str = "password"): Returns: The independent-user. """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import IndependentUser + return self._login_user_type( IndependentUser, email=email, password=password ) - def login_as(self, user: TypedUser, password: str = "password"): + def login_as(self, user: "TypedUser", password: str = "password"): """Log in as a user. The user instance needs to be a user proxy in order to know which credentials are required. @@ -268,6 +294,17 @@ def login_as(self, user: TypedUser, password: str = "password"): user: The user to log in as. password: The user's password. """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import ( + AdminSchoolTeacherUser, + IndependentUser, + NonAdminSchoolTeacherUser, + NonSchoolTeacherUser, + SchoolTeacherUser, + StudentUser, + TeacherUser, + ) + auth_user = None if isinstance(user, AdminSchoolTeacherUser): @@ -487,14 +524,14 @@ def options( # type: ignore[override] # pylint: enable=too-many-arguments,redefined-builtin -class APITestCase(TestCase, t.Generic[RequestUser]): +class APITestCase(TestCase, t.Generic["RequestUser"]): """Base API test case to be inherited by all other API test cases.""" - client: APIClient[RequestUser] - client_class: t.Type[APIClient[RequestUser]] = APIClient + client: APIClient["RequestUser"] + client_class: t.Type[APIClient["RequestUser"]] = APIClient @classmethod - def get_request_user_class(cls) -> t.Type[RequestUser]: + def get_request_user_class(cls) -> t.Type["RequestUser"]: """Get the request's user class. Returns: From 4922af6de25ee4098e2ad186b03ac03da43efb1f Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 10:53:38 +0000 Subject: [PATCH 04/56] fix: type as vars --- codeforlife/request.py | 4 +++- codeforlife/tests/api.py | 32 ++++++++++++-------------------- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/codeforlife/request.py b/codeforlife/request.py index 2e9d18a7..34a0823f 100644 --- a/codeforlife/request.py +++ b/codeforlife/request.py @@ -18,7 +18,9 @@ from .user.models import User from .user.models.session import SessionStore -AnyUser = t.TypeVar("AnyUser") + AnyUser = t.TypeVar("AnyUser", bound=User) +else: + AnyUser = t.TypeVar("AnyUser") # pylint: disable-next=missing-class-docstring diff --git a/codeforlife/tests/api.py b/codeforlife/tests/api.py index c757375b..d2c08756 100644 --- a/codeforlife/tests/api.py +++ b/codeforlife/tests/api.py @@ -17,24 +17,16 @@ from .test import TestCase if t.TYPE_CHECKING: - from ..user.models import ( - AdminSchoolTeacherUser, - AuthFactor, - IndependentUser, - NonAdminSchoolTeacherUser, - NonSchoolTeacherUser, - SchoolTeacherUser, - StudentUser, - TeacherUser, - TypedUser, - User, - ) + from ..user.models import TypedUser, User RequestUser = t.TypeVar("RequestUser", bound=User) LoginUser = t.TypeVar("LoginUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") + LoginUser = t.TypeVar("LoginUser") -class APIClient(_APIClient, t.Generic["RequestUser"]): +class APIClient(_APIClient, t.Generic[RequestUser]): """Base API client to be inherited by all other API clients.""" _test_case: "APITestCase[RequestUser]" @@ -58,7 +50,7 @@ def __init__( ) @classmethod - def get_request_user_class(cls) -> t.Type["RequestUser"]: + def get_request_user_class(cls) -> t.Type[RequestUser]: """Get the request's user class. Returns: @@ -120,7 +112,7 @@ def _make_assertions(): # Login Helpers # -------------------------------------------------------------------------- - def _login_user_type(self, user_type: t.Type["LoginUser"], **credentials): + def _login_user_type(self, user_type: t.Type[LoginUser], **credentials): # pylint: disable-next=import-outside-toplevel from ..user.models import AuthFactor @@ -140,7 +132,7 @@ def _login_user_type(self, user_type: t.Type["LoginUser"], **credentials): with patch.object(timezone, "now", return_value=now): assert super().login( request=self.request_factory.post( - user=t.cast("RequestUser", user) + user=t.cast(RequestUser, user) ), otp=otp, ), f'Failed to login with OTP "{otp}" at {now}.' @@ -524,14 +516,14 @@ def options( # type: ignore[override] # pylint: enable=too-many-arguments,redefined-builtin -class APITestCase(TestCase, t.Generic["RequestUser"]): +class APITestCase(TestCase, t.Generic[RequestUser]): """Base API test case to be inherited by all other API test cases.""" - client: APIClient["RequestUser"] - client_class: t.Type[APIClient["RequestUser"]] = APIClient + client: APIClient[RequestUser] + client_class: t.Type[APIClient[RequestUser]] = APIClient @classmethod - def get_request_user_class(cls) -> t.Type["RequestUser"]: + def get_request_user_class(cls) -> t.Type[RequestUser]: """Get the request's user class. Returns: From 84f4895c6d53e71c541d6484874e4f1650660291 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 11:04:02 +0000 Subject: [PATCH 05/56] fix: type imports --- codeforlife/tests/api_request_factory.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/codeforlife/tests/api_request_factory.py b/codeforlife/tests/api_request_factory.py index fabb0a70..61311c12 100644 --- a/codeforlife/tests/api_request_factory.py +++ b/codeforlife/tests/api_request_factory.py @@ -15,7 +15,13 @@ from rest_framework.test import APIRequestFactory as _APIRequestFactory from ..request import Request -from ..user.models import AnyUser + +if t.TYPE_CHECKING: + from ..user.models import User + + AnyUser = t.TypeVar("AnyUser", bound=User) +else: + AnyUser = t.TypeVar("AnyUser") class APIRequestFactory(_APIRequestFactory, t.Generic[AnyUser]): From e38004997edeec27a2da340fb0f543d2e2320ca7 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 11:41:43 +0000 Subject: [PATCH 06/56] fix: imports --- codeforlife/serializers/base.py | 8 +++++++- codeforlife/serializers/model.py | 8 +++++++- codeforlife/tests/model_serializer.py | 9 ++++++++- codeforlife/tests/model_view_set.py | 8 +++++++- codeforlife/views/model.py | 6 +++++- 5 files changed, 34 insertions(+), 5 deletions(-) diff --git a/codeforlife/serializers/base.py b/codeforlife/serializers/base.py index a49d0a2c..5c8fbf0c 100644 --- a/codeforlife/serializers/base.py +++ b/codeforlife/serializers/base.py @@ -11,7 +11,13 @@ from rest_framework.serializers import BaseSerializer as _BaseSerializer from ..request import Request -from ..user.models import AnyUser as RequestUser + +if t.TYPE_CHECKING: + from ..user.models import User + + RequestUser = t.TypeVar("RequestUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") # pylint: disable-next=abstract-method diff --git a/codeforlife/serializers/model.py b/codeforlife/serializers/model.py index 1300483c..ad99c180 100644 --- a/codeforlife/serializers/model.py +++ b/codeforlife/serializers/model.py @@ -13,9 +13,15 @@ from rest_framework.serializers import ValidationError as _ValidationError from ..types import DataDict, OrderedDataDict -from ..user.models import AnyUser as RequestUser from .base import BaseSerializer +if t.TYPE_CHECKING: + from ..user.models import User + + RequestUser = t.TypeVar("RequestUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") + AnyModel = t.TypeVar("AnyModel", bound=Model) diff --git a/codeforlife/tests/model_serializer.py b/codeforlife/tests/model_serializer.py index 8c6cc7bf..232e897b 100644 --- a/codeforlife/tests/model_serializer.py +++ b/codeforlife/tests/model_serializer.py @@ -15,10 +15,17 @@ from ..serializers import ModelListSerializer, ModelSerializer from ..types import DataDict -from ..user.models import AnyUser as RequestUser from .api_request_factory import APIRequestFactory from .test import TestCase +if t.TYPE_CHECKING: + from ..user.models import User + + RequestUser = t.TypeVar("RequestUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") + + AnyModel = t.TypeVar("AnyModel", bound=Model) diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index 855bbf02..bd944a9f 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -19,10 +19,16 @@ from ..permissions import Permission from ..serializers import BaseSerializer from ..types import DataDict, JsonDict, KwArgs -from ..user.models import AnyUser as RequestUser from ..views import ModelViewSet from .api import APIClient, APITestCase +if t.TYPE_CHECKING: + from ..user.models import User + + RequestUser = t.TypeVar("RequestUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") + AnyModel = t.TypeVar("AnyModel", bound=Model) # pylint: disable=no-member,too-many-arguments diff --git a/codeforlife/views/model.py b/codeforlife/views/model.py index 14ed52fa..71516997 100644 --- a/codeforlife/views/model.py +++ b/codeforlife/views/model.py @@ -16,7 +16,6 @@ from ..permissions import Permission from ..request import Request from ..types import KwArgs -from ..user.models import AnyUser as RequestUser from .api import APIView from .decorators import action @@ -24,6 +23,9 @@ if t.TYPE_CHECKING: # pragma: no cover from ..serializers import ModelListSerializer, ModelSerializer + from ..user.models import User + + RequestUser = t.TypeVar("RequestUser", bound=User) # NOTE: This raises an error during runtime. # pylint: disable-next=too-few-public-methods @@ -31,6 +33,8 @@ class _ModelViewSet(DrfModelViewSet[AnyModel], t.Generic[AnyModel]): pass else: + RequestUser = t.TypeVar("RequestUser") + # pylint: disable-next=too-many-ancestors class _ModelViewSet(DrfModelViewSet, t.Generic[AnyModel]): pass From c3312e6b2f30db23a3f3221323fd0ef0f46f5470 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 11:48:26 +0000 Subject: [PATCH 07/56] fix: imports --- codeforlife/serializers/__init__.py | 4 ++-- codeforlife/views/api.py | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/codeforlife/serializers/__init__.py b/codeforlife/serializers/__init__.py index 30c7968d..44782491 100644 --- a/codeforlife/serializers/__init__.py +++ b/codeforlife/serializers/__init__.py @@ -3,5 +3,5 @@ Created on 20/01/2024 at 11:19:12(+00:00). """ -from .base import * -from .model import * +from .base import BaseSerializer +from .model import ModelListSerializer, ModelSerializer diff --git a/codeforlife/views/api.py b/codeforlife/views/api.py index af7ad2db..07c0d459 100644 --- a/codeforlife/views/api.py +++ b/codeforlife/views/api.py @@ -8,7 +8,13 @@ from rest_framework.views import APIView as _APIView from ..request import Request -from ..user.models import AnyUser as RequestUser + +if t.TYPE_CHECKING: + from ..user.models import User + + RequestUser = t.TypeVar("RequestUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") # pylint: disable-next=missing-class-docstring From 1dd44d07d71f88568dc075dbce4a53d589cd8221 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 13:38:37 +0000 Subject: [PATCH 08/56] base request --- codeforlife/request.py | 56 ++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/codeforlife/request.py b/codeforlife/request.py index 34a0823f..ce000567 100644 --- a/codeforlife/request.py +++ b/codeforlife/request.py @@ -36,18 +36,42 @@ class HttpRequest(_HttpRequest): # pylint: disable-next=missing-class-docstring,abstract-method -class Request(_Request, t.Generic[AnyUser]): - session: "SessionStore" +class BaseRequest(_Request, t.Generic[AnyUser]): data: t.Any - - def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs): - super().__init__(*args, **kwargs) - self.user_class = user_class + session: "SessionStore" + user: t.Union[AnyUser, AnonymousUser] @property def query_params(self) -> t.Dict[str, str]: # type: ignore[override] return super().query_params + @property + def anon_user(self): + """The anonymous user that made the request.""" + return t.cast(AnonymousUser, self.user) + + @property + def auth_user(self): + """The authenticated user that made the request.""" + return t.cast(AnyUser, self.user) + + @property + def json_dict(self): + """The data as a json dictionary.""" + return t.cast(JsonDict, self.data) + + @property + def json_list(self): + """The data as a json list.""" + return t.cast(JsonList, self.data) + + +# pylint: disable-next=missing-class-docstring,abstract-method +class Request(BaseRequest[AnyUser], t.Generic[AnyUser]): + def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs): + super().__init__(*args, **kwargs) + self.user_class = user_class + @property def user(self): return t.cast(t.Union[AnyUser, AnonymousUser], super().user) @@ -67,16 +91,6 @@ def user(self, value): self._user = value self._request.user = value - @property - def anon_user(self): - """The anonymous user that made the request.""" - return t.cast(AnonymousUser, self.user) - - @property - def auth_user(self): - """The authenticated user that made the request.""" - return t.cast(AnyUser, self.user) - @property def teacher_user(self): """The authenticated teacher-user that made the request.""" @@ -134,13 +148,3 @@ def indy_user(self): from .user.models import IndependentUser return self.auth_user.as_type(IndependentUser) - - @property - def json_dict(self): - """The data as a json dictionary.""" - return t.cast(JsonDict, self.data) - - @property - def json_list(self): - """The data as a json list.""" - return t.cast(JsonList, self.data) From e0a30611f13d07138408678c5ffd9e204c7c1ba8 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 13:55:15 +0000 Subject: [PATCH 09/56] fix: types --- codeforlife/request.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/codeforlife/request.py b/codeforlife/request.py index ce000567..41fa0565 100644 --- a/codeforlife/request.py +++ b/codeforlife/request.py @@ -7,7 +7,7 @@ import typing as t -from django.contrib.auth.models import AnonymousUser +from django.contrib.auth.models import AbstractBaseUser, AnonymousUser from django.core.handlers.wsgi import WSGIRequest as _WSGIRequest from django.http import HttpRequest as _HttpRequest from rest_framework.request import Request as _Request @@ -22,24 +22,25 @@ else: AnyUser = t.TypeVar("AnyUser") +AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) + # pylint: disable-next=missing-class-docstring -class WSGIRequest(_WSGIRequest): +class WSGIRequest(_WSGIRequest, t.Generic[AnyUser]): session: "SessionStore" - user: t.Union["User", AnonymousUser] + user: t.Union[AnyUser, AnonymousUser] # pylint: disable-next=missing-class-docstring -class HttpRequest(_HttpRequest): +class HttpRequest(_HttpRequest, t.Generic[AnyUser]): session: "SessionStore" - user: t.Union["User", AnonymousUser] + user: t.Union[AnyUser, AnonymousUser] # pylint: disable-next=missing-class-docstring,abstract-method -class BaseRequest(_Request, t.Generic[AnyUser]): +class BaseRequest(_Request, t.Generic[AnyAbstractBaseUser]): data: t.Any - session: "SessionStore" - user: t.Union[AnyUser, AnonymousUser] + user: t.Union[AnyAbstractBaseUser, AnonymousUser] @property def query_params(self) -> t.Dict[str, str]: # type: ignore[override] @@ -53,7 +54,7 @@ def anon_user(self): @property def auth_user(self): """The authenticated user that made the request.""" - return t.cast(AnyUser, self.user) + return t.cast(AnyAbstractBaseUser, self.user) @property def json_dict(self): @@ -68,6 +69,8 @@ def json_list(self): # pylint: disable-next=missing-class-docstring,abstract-method class Request(BaseRequest[AnyUser], t.Generic[AnyUser]): + session: "SessionStore" + def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs): super().__init__(*args, **kwargs) self.user_class = user_class From 40076c19e58401a56b3794996efb4179d54da186 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 13:57:52 +0000 Subject: [PATCH 10/56] ignore duplicate code --- codeforlife/request.py | 1 + codeforlife/tests/api.py | 1 + codeforlife/tests/api_request_factory.py | 1 + codeforlife/tests/model_serializer.py | 1 + codeforlife/tests/model_view_set.py | 1 + 5 files changed, 5 insertions(+) diff --git a/codeforlife/request.py b/codeforlife/request.py index 41fa0565..39cd13f9 100644 --- a/codeforlife/request.py +++ b/codeforlife/request.py @@ -14,6 +14,7 @@ from .types import JsonDict, JsonList +# pylint: disable-next=duplicate-code if t.TYPE_CHECKING: from .user.models import User from .user.models.session import SessionStore diff --git a/codeforlife/tests/api.py b/codeforlife/tests/api.py index d2c08756..e70fb1c3 100644 --- a/codeforlife/tests/api.py +++ b/codeforlife/tests/api.py @@ -16,6 +16,7 @@ from .api_request_factory import APIRequestFactory from .test import TestCase +# pylint: disable-next=duplicate-code if t.TYPE_CHECKING: from ..user.models import TypedUser, User diff --git a/codeforlife/tests/api_request_factory.py b/codeforlife/tests/api_request_factory.py index 61311c12..7565c021 100644 --- a/codeforlife/tests/api_request_factory.py +++ b/codeforlife/tests/api_request_factory.py @@ -16,6 +16,7 @@ from ..request import Request +# pylint: disable-next=duplicate-code if t.TYPE_CHECKING: from ..user.models import User diff --git a/codeforlife/tests/model_serializer.py b/codeforlife/tests/model_serializer.py index 232e897b..dcda7bba 100644 --- a/codeforlife/tests/model_serializer.py +++ b/codeforlife/tests/model_serializer.py @@ -18,6 +18,7 @@ from .api_request_factory import APIRequestFactory from .test import TestCase +# pylint: disable-next=duplicate-code if t.TYPE_CHECKING: from ..user.models import User diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index bd944a9f..3a373021 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -22,6 +22,7 @@ from ..views import ModelViewSet from .api import APIClient, APITestCase +# pylint: disable-next=duplicate-code if t.TYPE_CHECKING: from ..user.models import User From 67e6b17810d65b0971fdf840d4ec724b6a5eb5b6 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 13:59:23 +0000 Subject: [PATCH 11/56] disable duplicate code --- codeforlife/serializers/base.py | 1 + codeforlife/serializers/model.py | 1 + codeforlife/views/api.py | 1 + codeforlife/views/model.py | 1 + 4 files changed, 4 insertions(+) diff --git a/codeforlife/serializers/base.py b/codeforlife/serializers/base.py index 5c8fbf0c..8f43af79 100644 --- a/codeforlife/serializers/base.py +++ b/codeforlife/serializers/base.py @@ -12,6 +12,7 @@ from ..request import Request +# pylint: disable-next=duplicate-code if t.TYPE_CHECKING: from ..user.models import User diff --git a/codeforlife/serializers/model.py b/codeforlife/serializers/model.py index ad99c180..2a1476b0 100644 --- a/codeforlife/serializers/model.py +++ b/codeforlife/serializers/model.py @@ -15,6 +15,7 @@ from ..types import DataDict, OrderedDataDict from .base import BaseSerializer +# pylint: disable-next=duplicate-code if t.TYPE_CHECKING: from ..user.models import User diff --git a/codeforlife/views/api.py b/codeforlife/views/api.py index 07c0d459..fe92dfe1 100644 --- a/codeforlife/views/api.py +++ b/codeforlife/views/api.py @@ -9,6 +9,7 @@ from ..request import Request +# pylint: disable-next=duplicate-code if t.TYPE_CHECKING: from ..user.models import User diff --git a/codeforlife/views/model.py b/codeforlife/views/model.py index 71516997..3f5f6a39 100644 --- a/codeforlife/views/model.py +++ b/codeforlife/views/model.py @@ -21,6 +21,7 @@ AnyModel = t.TypeVar("AnyModel", bound=Model) +# pylint: disable-next=duplicate-code if t.TYPE_CHECKING: # pragma: no cover from ..serializers import ModelListSerializer, ModelSerializer from ..user.models import User From 9c0463d546ed1dde40cf15fb69f7fb20b32d85d5 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 16:29:06 +0000 Subject: [PATCH 12/56] split request objects --- codeforlife/request/__init__.py | 8 +++++ codeforlife/{request.py => request/drf.py} | 40 +++++++--------------- codeforlife/request/http.py | 32 +++++++++++++++++ codeforlife/request/wsgi.py | 32 +++++++++++++++++ 4 files changed, 85 insertions(+), 27 deletions(-) create mode 100644 codeforlife/request/__init__.py rename codeforlife/{request.py => request/drf.py} (77%) create mode 100644 codeforlife/request/http.py create mode 100644 codeforlife/request/wsgi.py diff --git a/codeforlife/request/__init__.py b/codeforlife/request/__init__.py new file mode 100644 index 00000000..a8ed8a0c --- /dev/null +++ b/codeforlife/request/__init__.py @@ -0,0 +1,8 @@ +""" +© Ocado Group +Created on 05/11/2024 at 14:40:32(+00:00). +""" + +from .drf import BaseRequest, Request +from .http import BaseHttpRequest, HttpRequest +from .wsgi import BaseWSGIRequest, WSGIRequest diff --git a/codeforlife/request.py b/codeforlife/request/drf.py similarity index 77% rename from codeforlife/request.py rename to codeforlife/request/drf.py index 39cd13f9..4c30da03 100644 --- a/codeforlife/request.py +++ b/codeforlife/request/drf.py @@ -1,23 +1,21 @@ """ © Ocado Group -Created on 19/02/2024 at 15:28:22(+00:00). +Created on 05/11/2024 at 14:41:58(+00:00). -Override default request objects. +Custom Request which hints to our custom types. """ import typing as t from django.contrib.auth.models import AbstractBaseUser, AnonymousUser -from django.core.handlers.wsgi import WSGIRequest as _WSGIRequest -from django.http import HttpRequest as _HttpRequest from rest_framework.request import Request as _Request -from .types import JsonDict, JsonList +from ..types import JsonDict, JsonList # pylint: disable-next=duplicate-code if t.TYPE_CHECKING: - from .user.models import User - from .user.models.session import SessionStore + from ..user.models import User + from ..user.models.session import SessionStore AnyUser = t.TypeVar("AnyUser", bound=User) else: @@ -26,18 +24,6 @@ AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) -# pylint: disable-next=missing-class-docstring -class WSGIRequest(_WSGIRequest, t.Generic[AnyUser]): - session: "SessionStore" - user: t.Union[AnyUser, AnonymousUser] - - -# pylint: disable-next=missing-class-docstring -class HttpRequest(_HttpRequest, t.Generic[AnyUser]): - session: "SessionStore" - user: t.Union[AnyUser, AnonymousUser] - - # pylint: disable-next=missing-class-docstring,abstract-method class BaseRequest(_Request, t.Generic[AnyAbstractBaseUser]): data: t.Any @@ -83,7 +69,7 @@ def user(self): @user.setter def user(self, value): # pylint: disable-next=import-outside-toplevel - from .user.models import User + from ..user.models import User if ( isinstance(value, User) @@ -99,7 +85,7 @@ def user(self, value): def teacher_user(self): """The authenticated teacher-user that made the request.""" # pylint: disable-next=import-outside-toplevel - from .user.models import TeacherUser + from ..user.models import TeacherUser return self.auth_user.as_type(TeacherUser) @@ -107,7 +93,7 @@ def teacher_user(self): def school_teacher_user(self): """The authenticated school-teacher-user that made the request.""" # pylint: disable-next=import-outside-toplevel - from .user.models import SchoolTeacherUser + from ..user.models import SchoolTeacherUser return self.auth_user.as_type(SchoolTeacherUser) @@ -115,7 +101,7 @@ def school_teacher_user(self): def admin_school_teacher_user(self): """The authenticated admin-school-teacher-user that made the request.""" # pylint: disable-next=import-outside-toplevel - from .user.models import AdminSchoolTeacherUser + from ..user.models import AdminSchoolTeacherUser return self.auth_user.as_type(AdminSchoolTeacherUser) @@ -125,7 +111,7 @@ def non_admin_school_teacher_user(self): The authenticated non-admin-school-teacher-user that made the request. """ # pylint: disable-next=import-outside-toplevel - from .user.models import NonAdminSchoolTeacherUser + from ..user.models import NonAdminSchoolTeacherUser return self.auth_user.as_type(NonAdminSchoolTeacherUser) @@ -133,7 +119,7 @@ def non_admin_school_teacher_user(self): def non_school_teacher_user(self): """The authenticated non-school-teacher-user that made the request.""" # pylint: disable-next=import-outside-toplevel - from .user.models import NonSchoolTeacherUser + from ..user.models import NonSchoolTeacherUser return self.auth_user.as_type(NonSchoolTeacherUser) @@ -141,7 +127,7 @@ def non_school_teacher_user(self): def student_user(self): """The authenticated student-user that made the request.""" # pylint: disable-next=import-outside-toplevel - from .user.models import StudentUser + from ..user.models import StudentUser return self.auth_user.as_type(StudentUser) @@ -149,6 +135,6 @@ def student_user(self): def indy_user(self): """The authenticated independent-user that made the request.""" # pylint: disable-next=import-outside-toplevel - from .user.models import IndependentUser + from ..user.models import IndependentUser return self.auth_user.as_type(IndependentUser) diff --git a/codeforlife/request/http.py b/codeforlife/request/http.py new file mode 100644 index 00000000..70678f8a --- /dev/null +++ b/codeforlife/request/http.py @@ -0,0 +1,32 @@ +""" +© Ocado Group +Created on 05/11/2024 at 14:41:58(+00:00). + +Custom HttpRequest which hints to our custom types. +""" + +import typing as t + +from django.contrib.auth.models import AbstractBaseUser, AnonymousUser +from django.http import HttpRequest as _HttpRequest + +# pylint: disable-next=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User + from ..user.models.session import SessionStore + + AnyUser = t.TypeVar("AnyUser", bound=User) +else: + AnyUser = t.TypeVar("AnyUser") + +AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) + + +# pylint: disable-next=missing-class-docstring +class BaseHttpRequest(_HttpRequest, t.Generic[AnyAbstractBaseUser]): + user: t.Union[AnyAbstractBaseUser, AnonymousUser] + + +# pylint: disable-next=missing-class-docstring +class HttpRequest(BaseHttpRequest[AnyUser], t.Generic[AnyUser]): + session: "SessionStore" diff --git a/codeforlife/request/wsgi.py b/codeforlife/request/wsgi.py new file mode 100644 index 00000000..b9764c8e --- /dev/null +++ b/codeforlife/request/wsgi.py @@ -0,0 +1,32 @@ +""" +© Ocado Group +Created on 05/11/2024 at 14:41:58(+00:00). + +Custom WSGIRequest which hints to our custom types. +""" + +import typing as t + +from django.contrib.auth.models import AbstractBaseUser, AnonymousUser +from django.core.handlers.wsgi import WSGIRequest as _WSGIRequest + +# pylint: disable-next=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User + from ..user.models.session import SessionStore + + AnyUser = t.TypeVar("AnyUser", bound=User) +else: + AnyUser = t.TypeVar("AnyUser") + +AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) + + +# pylint: disable-next=missing-class-docstring +class BaseWSGIRequest(_WSGIRequest, t.Generic[AnyAbstractBaseUser]): + user: t.Union[AnyAbstractBaseUser, AnonymousUser] + + +# pylint: disable-next=missing-class-docstring +class WSGIRequest(BaseWSGIRequest[AnyUser], t.Generic[AnyUser]): + session: "SessionStore" From 623f9952eacaf0a6ccb1f9d6ae3954e488c529b4 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 16:38:17 +0000 Subject: [PATCH 13/56] session param --- codeforlife/request/drf.py | 9 +++++---- codeforlife/request/http.py | 9 ++++++--- codeforlife/request/wsgi.py | 9 ++++++--- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/codeforlife/request/drf.py b/codeforlife/request/drf.py index 4c30da03..7162e0ff 100644 --- a/codeforlife/request/drf.py +++ b/codeforlife/request/drf.py @@ -8,6 +8,7 @@ import typing as t from django.contrib.auth.models import AbstractBaseUser, AnonymousUser +from django.contrib.sessions.backends.db import SessionStore as DBStore from rest_framework.request import Request as _Request from ..types import JsonDict, JsonList @@ -21,12 +22,14 @@ else: AnyUser = t.TypeVar("AnyUser") +AnyDBStore = t.TypeVar("AnyDBStore", bound=DBStore) AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) # pylint: disable-next=missing-class-docstring,abstract-method -class BaseRequest(_Request, t.Generic[AnyAbstractBaseUser]): +class BaseRequest(_Request, t.Generic[AnyDBStore, AnyAbstractBaseUser]): data: t.Any + session: AnyDBStore user: t.Union[AnyAbstractBaseUser, AnonymousUser] @property @@ -55,9 +58,7 @@ def json_list(self): # pylint: disable-next=missing-class-docstring,abstract-method -class Request(BaseRequest[AnyUser], t.Generic[AnyUser]): - session: "SessionStore" - +class Request(BaseRequest["SessionStore", AnyUser], t.Generic[AnyUser]): def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs): super().__init__(*args, **kwargs) self.user_class = user_class diff --git a/codeforlife/request/http.py b/codeforlife/request/http.py index 70678f8a..3595c373 100644 --- a/codeforlife/request/http.py +++ b/codeforlife/request/http.py @@ -8,6 +8,7 @@ import typing as t from django.contrib.auth.models import AbstractBaseUser, AnonymousUser +from django.contrib.sessions.backends.db import SessionStore as DBStore from django.http import HttpRequest as _HttpRequest # pylint: disable-next=duplicate-code @@ -19,14 +20,16 @@ else: AnyUser = t.TypeVar("AnyUser") +AnyDBStore = t.TypeVar("AnyDBStore", bound=DBStore) AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) # pylint: disable-next=missing-class-docstring -class BaseHttpRequest(_HttpRequest, t.Generic[AnyAbstractBaseUser]): +class BaseHttpRequest(_HttpRequest, t.Generic[AnyDBStore, AnyAbstractBaseUser]): + session: AnyDBStore user: t.Union[AnyAbstractBaseUser, AnonymousUser] # pylint: disable-next=missing-class-docstring -class HttpRequest(BaseHttpRequest[AnyUser], t.Generic[AnyUser]): - session: "SessionStore" +class HttpRequest(BaseHttpRequest["SessionStore", AnyUser], t.Generic[AnyUser]): + pass diff --git a/codeforlife/request/wsgi.py b/codeforlife/request/wsgi.py index b9764c8e..72b3ca69 100644 --- a/codeforlife/request/wsgi.py +++ b/codeforlife/request/wsgi.py @@ -8,6 +8,7 @@ import typing as t from django.contrib.auth.models import AbstractBaseUser, AnonymousUser +from django.contrib.sessions.backends.db import SessionStore as DBStore from django.core.handlers.wsgi import WSGIRequest as _WSGIRequest # pylint: disable-next=duplicate-code @@ -19,14 +20,16 @@ else: AnyUser = t.TypeVar("AnyUser") +AnyDBStore = t.TypeVar("AnyDBStore", bound=DBStore) AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) # pylint: disable-next=missing-class-docstring -class BaseWSGIRequest(_WSGIRequest, t.Generic[AnyAbstractBaseUser]): +class BaseWSGIRequest(_WSGIRequest, t.Generic[AnyDBStore, AnyAbstractBaseUser]): + session: AnyDBStore user: t.Union[AnyAbstractBaseUser, AnonymousUser] # pylint: disable-next=missing-class-docstring -class WSGIRequest(BaseWSGIRequest[AnyUser], t.Generic[AnyUser]): - session: "SessionStore" +class WSGIRequest(BaseWSGIRequest["SessionStore", AnyUser], t.Generic[AnyUser]): + pass From a3fc0c7df30a981ba98fc42ea6c398a5437ff38c Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 16:43:29 +0000 Subject: [PATCH 14/56] disable duplicate code --- codeforlife/request/drf.py | 3 ++- codeforlife/request/http.py | 3 ++- codeforlife/request/wsgi.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/codeforlife/request/drf.py b/codeforlife/request/drf.py index 7162e0ff..44b34f8d 100644 --- a/codeforlife/request/drf.py +++ b/codeforlife/request/drf.py @@ -13,7 +13,7 @@ from ..types import JsonDict, JsonList -# pylint: disable-next=duplicate-code +# pylint: disable=duplicate-code if t.TYPE_CHECKING: from ..user.models import User from ..user.models.session import SessionStore @@ -24,6 +24,7 @@ AnyDBStore = t.TypeVar("AnyDBStore", bound=DBStore) AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) +# pylint: enable=duplicate-code # pylint: disable-next=missing-class-docstring,abstract-method diff --git a/codeforlife/request/http.py b/codeforlife/request/http.py index 3595c373..a64c13e7 100644 --- a/codeforlife/request/http.py +++ b/codeforlife/request/http.py @@ -11,7 +11,7 @@ from django.contrib.sessions.backends.db import SessionStore as DBStore from django.http import HttpRequest as _HttpRequest -# pylint: disable-next=duplicate-code +# pylint: disable=duplicate-code if t.TYPE_CHECKING: from ..user.models import User from ..user.models.session import SessionStore @@ -22,6 +22,7 @@ AnyDBStore = t.TypeVar("AnyDBStore", bound=DBStore) AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) +# pylint: enable=duplicate-code # pylint: disable-next=missing-class-docstring diff --git a/codeforlife/request/wsgi.py b/codeforlife/request/wsgi.py index 72b3ca69..dd4778d3 100644 --- a/codeforlife/request/wsgi.py +++ b/codeforlife/request/wsgi.py @@ -11,7 +11,7 @@ from django.contrib.sessions.backends.db import SessionStore as DBStore from django.core.handlers.wsgi import WSGIRequest as _WSGIRequest -# pylint: disable-next=duplicate-code +# pylint: disable=duplicate-code if t.TYPE_CHECKING: from ..user.models import User from ..user.models.session import SessionStore @@ -22,6 +22,7 @@ AnyDBStore = t.TypeVar("AnyDBStore", bound=DBStore) AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) +# pylint: enable=duplicate-code # pylint: disable-next=missing-class-docstring From 7342906c514e953a9dc53a3b43c30208fcdc35d5 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 17:39:25 +0000 Subject: [PATCH 15/56] fix: abstract api request factory --- codeforlife/tests/api_request_factory.py | 119 ++++++++++++++--------- 1 file changed, 74 insertions(+), 45 deletions(-) diff --git a/codeforlife/tests/api_request_factory.py b/codeforlife/tests/api_request_factory.py index 7565c021..0758e1a5 100644 --- a/codeforlife/tests/api_request_factory.py +++ b/codeforlife/tests/api_request_factory.py @@ -5,6 +5,7 @@ import typing as t +from django.contrib.auth.models import AbstractBaseUser from django.core.handlers.wsgi import WSGIRequest from rest_framework.parsers import ( FileUploadParser, @@ -14,9 +15,9 @@ ) from rest_framework.test import APIRequestFactory as _APIRequestFactory -from ..request import Request +from ..request import BaseRequest, Request -# pylint: disable-next=duplicate-code +# pylint: disable=duplicate-code if t.TYPE_CHECKING: from ..user.models import User @@ -24,40 +25,33 @@ else: AnyUser = t.TypeVar("AnyUser") +AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest) +AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) +# pylint: enable=duplicate-code -class APIRequestFactory(_APIRequestFactory, t.Generic[AnyUser]): - """Custom API request factory that returns DRF's Request object.""" - - def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs): - super().__init__(*args, **kwargs) - self.user_class = user_class - - @classmethod - def get_user_class(cls) -> t.Type[AnyUser]: - """Get the user class. - - Returns: - The user class. - """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] - def request(self, user: t.Optional[AnyUser] = None, **kwargs): - wsgi_request = t.cast(WSGIRequest, super().request(**kwargs)) +class BaseAPIRequestFactory( + _APIRequestFactory, t.Generic[AnyBaseRequest, AnyAbstractBaseUser] +): + """Custom API request factory that returns DRF's Request object.""" - request = Request( - self.user_class, - wsgi_request, - parsers=[ - JSONParser(), - FormParser(), - MultiPartParser(), - FileUploadParser(), - ], + def _init_request(self, wsgi_request: WSGIRequest): + return t.cast( + AnyBaseRequest, + BaseRequest( + wsgi_request, + parsers=[ + JSONParser(), + FormParser(), + MultiPartParser(), + FileUploadParser(), + ], + ), ) + def request(self, user: t.Optional[AnyAbstractBaseUser] = None, **kwargs): + wsgi_request = t.cast(WSGIRequest, super().request(**kwargs)) + request = self._init_request(wsgi_request) if user: # pylint: disable-next=attribute-defined-outside-init request.user = user @@ -72,11 +66,11 @@ def generic( data: t.Optional[str] = None, content_type: t.Optional[str] = None, secure: bool = True, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): return t.cast( - Request[AnyUser], + AnyBaseRequest, super().generic( method, path or "/", @@ -92,11 +86,11 @@ def get( # type: ignore[override] self, path: t.Optional[str] = None, data: t.Any = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): return t.cast( - Request[AnyUser], + AnyBaseRequest, super().get( path or "/", data, @@ -113,14 +107,14 @@ def post( # type: ignore[override] # pylint: disable-next=redefined-builtin format: t.Optional[str] = None, content_type: t.Optional[str] = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): if format is None and content_type is None: format = "json" return t.cast( - Request[AnyUser], + AnyBaseRequest, super().post( path or "/", data, @@ -139,14 +133,14 @@ def put( # type: ignore[override] # pylint: disable-next=redefined-builtin format: t.Optional[str] = None, content_type: t.Optional[str] = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): if format is None and content_type is None: format = "json" return t.cast( - Request[AnyUser], + AnyBaseRequest, super().put( path or "/", data, @@ -165,14 +159,14 @@ def patch( # type: ignore[override] # pylint: disable-next=redefined-builtin format: t.Optional[str] = None, content_type: t.Optional[str] = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): if format is None and content_type is None: format = "json" return t.cast( - Request[AnyUser], + AnyBaseRequest, super().patch( path or "/", data, @@ -191,14 +185,14 @@ def delete( # type: ignore[override] # pylint: disable-next=redefined-builtin format: t.Optional[str] = None, content_type: t.Optional[str] = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): if format is None and content_type is None: format = "json" return t.cast( - Request[AnyUser], + AnyBaseRequest, super().delete( path or "/", data, @@ -217,14 +211,14 @@ def options( # type: ignore[override] # pylint: disable-next=redefined-builtin format: t.Optional[str] = None, content_type: t.Optional[str] = None, - user: t.Optional[AnyUser] = None, + user: t.Optional[AnyAbstractBaseUser] = None, **extra ): if format is None and content_type is None: format = "json" return t.cast( - Request[AnyUser], + AnyBaseRequest, super().options( path or "/", data or {}, @@ -234,3 +228,38 @@ def options( # type: ignore[override] **extra, ), ) + + +class APIRequestFactory( + BaseAPIRequestFactory[Request[AnyUser], AnyUser], + t.Generic[AnyUser], +): + """Custom API request factory that returns DRF's Request object.""" + + def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs): + super().__init__(*args, **kwargs) + self.user_class = user_class + + @classmethod + def get_user_class(cls) -> t.Type[AnyUser]: + """Get the user class. + + Returns: + The user class. + """ + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 0 + ] + + def _init_request(self, wsgi_request): + return Request[AnyUser]( + self.user_class, + wsgi_request, + parsers=[ + JSONParser(), + FormParser(), + MultiPartParser(), + FileUploadParser(), + ], + ) From fb58a80778d99aadde76e0ddaa784f57a126f571 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 17:40:08 +0000 Subject: [PATCH 16/56] import base api request factory --- codeforlife/tests/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeforlife/tests/__init__.py b/codeforlife/tests/__init__.py index fdaee248..9943912c 100644 --- a/codeforlife/tests/__init__.py +++ b/codeforlife/tests/__init__.py @@ -6,7 +6,7 @@ """ from .api import APIClient, APITestCase -from .api_request_factory import APIRequestFactory +from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory from .cron import CronTestCase from .model import ModelTestCase from .model_serializer import ( From b5242a933bb8ae6826516148534990f7f586326f Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 17:58:21 +0000 Subject: [PATCH 17/56] split model list serializer --- codeforlife/serializers/__init__.py | 3 +- codeforlife/serializers/model.py | 170 +---------------------- codeforlife/serializers/model_list.py | 188 ++++++++++++++++++++++++++ 3 files changed, 193 insertions(+), 168 deletions(-) create mode 100644 codeforlife/serializers/model_list.py diff --git a/codeforlife/serializers/__init__.py b/codeforlife/serializers/__init__.py index 44782491..3e8c453c 100644 --- a/codeforlife/serializers/__init__.py +++ b/codeforlife/serializers/__init__.py @@ -4,4 +4,5 @@ """ from .base import BaseSerializer -from .model import ModelListSerializer, ModelSerializer +from .model import ModelSerializer +from .model_list import ModelListSerializer diff --git a/codeforlife/serializers/model.py b/codeforlife/serializers/model.py index 2a1476b0..e2fca15a 100644 --- a/codeforlife/serializers/model.py +++ b/codeforlife/serializers/model.py @@ -8,14 +8,12 @@ import typing as t from django.db.models import Model -from rest_framework.serializers import ListSerializer as _ListSerializer from rest_framework.serializers import ModelSerializer as _ModelSerializer -from rest_framework.serializers import ValidationError as _ValidationError -from ..types import DataDict, OrderedDataDict +from ..types import DataDict from .base import BaseSerializer -# pylint: disable-next=duplicate-code +# pylint: disable=duplicate-code if t.TYPE_CHECKING: from ..user.models import User @@ -24,11 +22,7 @@ RequestUser = t.TypeVar("RequestUser") AnyModel = t.TypeVar("AnyModel", bound=Model) - - -BulkCreateDataList = t.List[DataDict] -BulkUpdateDataDict = t.Dict[t.Any, DataDict] -Data = t.Union[BulkCreateDataList, BulkUpdateDataDict] +# pylint: enable=duplicate-code class ModelSerializer( @@ -67,161 +61,3 @@ def validate(self, attrs: DataDict): # pylint: disable-next=useless-parent-delegation def to_representation(self, instance: AnyModel) -> DataDict: return super().to_representation(instance) - - -class ModelListSerializer( - BaseSerializer[RequestUser], - _ListSerializer[t.List[AnyModel]], - t.Generic[RequestUser, AnyModel], -): - """Base model list serializer for all model list serializers. - - Inherit this class if you wish to custom handle bulk create and/or update. - - class UserListSerializer(ModelListSerializer[User, User]): - def create(self, validated_data): - ... - - def update(self, instance, validated_data): - ... - - class UserSerializer(ModelSerializer[User, User]): - class Meta: - model = User - list_serializer_class = UserListSerializer - """ - - instance: t.Optional[t.List[AnyModel]] - batch_size: t.Optional[int] = None - - @property - def view(self): - # NOTE: import outside top-level to avoid circular imports. - # pylint: disable-next=import-outside-toplevel - from ..views import ModelViewSet - - return t.cast(ModelViewSet[RequestUser, AnyModel], super().view) - - @property - def non_none_instance(self): - """Casts the instance to not None.""" - return t.cast(t.List[AnyModel], self.instance) - - @classmethod - def get_model_class(cls) -> t.Type[AnyModel]: - """Get the model view set's class. - - Returns: - The model view set's class. - """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] - - def __init__(self, *args, **kwargs): - instance = args[0] if args else kwargs.pop("instance", None) - if instance is not None and not isinstance(instance, list): - instance = list(instance) - - super().__init__(instance, *args[1:], **kwargs) - - def create(self, validated_data: t.List[DataDict]) -> t.List[AnyModel]: - """Bulk create many instances of a model. - - https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-create - - Args: - validated_data: The data used to create the models. - - Returns: - The models. - """ - model_class = self.get_model_class() - return model_class.objects.bulk_create( # type: ignore[attr-defined] - objs=[model_class(**data) for data in validated_data], - batch_size=self.batch_size, - ) - - def update( - self, - instance: t.List[AnyModel], - validated_data: t.List[DataDict], - ) -> t.List[AnyModel]: - """Bulk update many instances of a model. - - https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-update - - Args: - instance: The models to update. - validated_data: The field-value pairs to update for each model. - - Returns: - The models. - """ - # Models and data must have equal length and be ordered the same! - for model, data in zip(instance, validated_data): - for field, value in data.items(): - setattr(model, field, value) - - model_class = self.get_model_class() - model_class.objects.bulk_update( # type: ignore[attr-defined] - objs=instance, - fields={field for data in validated_data for field in data.keys()}, - batch_size=self.batch_size, - ) - - return instance - - def validate(self, attrs: t.List[DataDict]): - # If performing a bulk create. - if self.instance is None: - if len(attrs) == 0: - raise _ValidationError( - "Nothing to create.", - code="nothing_to_create", - ) - - # Else, performing a bulk update. - else: - if len(attrs) == 0: - raise _ValidationError( - "Nothing to update.", - code="nothing_to_update", - ) - if len(attrs) != len(self.instance): - raise _ValidationError( - "Some models do not exist.", - code="models_do_not_exist", - ) - - return attrs - - def to_internal_value(self, data: Data): - # If performing a bulk create. - if self.instance is None: - data = t.cast(BulkCreateDataList, data) - - return t.cast( - t.List[OrderedDataDict], - super().to_internal_value(data), - ) - - # Else, performing a bulk update. - data = t.cast(BulkUpdateDataDict, data) - data_items = list(data.items()) - - # Models and data are required to be sorted by the lookup field. - data_items.sort(key=lambda item: item[0]) - self.instance.sort( - key=lambda model: getattr(model, self.view.lookup_field) - ) - - return t.cast( - t.List[OrderedDataDict], - super().to_internal_value([item[1] for item in data_items]), - ) - - # pylint: disable-next=useless-parent-delegation,arguments-renamed - def to_representation(self, instance: t.List[AnyModel]) -> t.List[DataDict]: - return super().to_representation(instance) diff --git a/codeforlife/serializers/model_list.py b/codeforlife/serializers/model_list.py new file mode 100644 index 00000000..d292c2bd --- /dev/null +++ b/codeforlife/serializers/model_list.py @@ -0,0 +1,188 @@ +""" +© Ocado Group +Created on 05/11/2024 at 17:53:40(+00:00). + +Base model list serializers. +""" + +import typing as t + +from django.db.models import Model +from rest_framework.serializers import ListSerializer as _ListSerializer +from rest_framework.serializers import ValidationError as _ValidationError + +from ..types import DataDict, OrderedDataDict +from .base import BaseSerializer + +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User + + RequestUser = t.TypeVar("RequestUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") + +AnyModel = t.TypeVar("AnyModel", bound=Model) +# pylint: enable=duplicate-code + +BulkCreateDataList = t.List[DataDict] +BulkUpdateDataDict = t.Dict[t.Any, DataDict] +Data = t.Union[BulkCreateDataList, BulkUpdateDataDict] + + +class ModelListSerializer( + BaseSerializer[RequestUser], + _ListSerializer[t.List[AnyModel]], + t.Generic[RequestUser, AnyModel], +): + """Base model list serializer for all model list serializers. + + Inherit this class if you wish to custom handle bulk create and/or update. + + class UserListSerializer(ModelListSerializer[User, User]): + def create(self, validated_data): + ... + + def update(self, instance, validated_data): + ... + + class UserSerializer(ModelSerializer[User, User]): + class Meta: + model = User + list_serializer_class = UserListSerializer + """ + + instance: t.Optional[t.List[AnyModel]] + batch_size: t.Optional[int] = None + + @property + def view(self): + # NOTE: import outside top-level to avoid circular imports. + # pylint: disable-next=import-outside-toplevel + from ..views import ModelViewSet + + return t.cast(ModelViewSet[RequestUser, AnyModel], super().view) + + @property + def non_none_instance(self): + """Casts the instance to not None.""" + return t.cast(t.List[AnyModel], self.instance) + + @classmethod + def get_model_class(cls) -> t.Type[AnyModel]: + """Get the model view set's class. + + Returns: + The model view set's class. + """ + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 0 + ] + + def __init__(self, *args, **kwargs): + instance = args[0] if args else kwargs.pop("instance", None) + if instance is not None and not isinstance(instance, list): + instance = list(instance) + + super().__init__(instance, *args[1:], **kwargs) + + def create(self, validated_data: t.List[DataDict]) -> t.List[AnyModel]: + """Bulk create many instances of a model. + + https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-create + + Args: + validated_data: The data used to create the models. + + Returns: + The models. + """ + model_class = self.get_model_class() + return model_class.objects.bulk_create( # type: ignore[attr-defined] + objs=[model_class(**data) for data in validated_data], + batch_size=self.batch_size, + ) + + def update( + self, + instance: t.List[AnyModel], + validated_data: t.List[DataDict], + ) -> t.List[AnyModel]: + """Bulk update many instances of a model. + + https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-update + + Args: + instance: The models to update. + validated_data: The field-value pairs to update for each model. + + Returns: + The models. + """ + # Models and data must have equal length and be ordered the same! + for model, data in zip(instance, validated_data): + for field, value in data.items(): + setattr(model, field, value) + + model_class = self.get_model_class() + model_class.objects.bulk_update( # type: ignore[attr-defined] + objs=instance, + fields={field for data in validated_data for field in data.keys()}, + batch_size=self.batch_size, + ) + + return instance + + def validate(self, attrs: t.List[DataDict]): + # If performing a bulk create. + if self.instance is None: + if len(attrs) == 0: + raise _ValidationError( + "Nothing to create.", + code="nothing_to_create", + ) + + # Else, performing a bulk update. + else: + if len(attrs) == 0: + raise _ValidationError( + "Nothing to update.", + code="nothing_to_update", + ) + if len(attrs) != len(self.instance): + raise _ValidationError( + "Some models do not exist.", + code="models_do_not_exist", + ) + + return attrs + + def to_internal_value(self, data: Data): + # If performing a bulk create. + if self.instance is None: + data = t.cast(BulkCreateDataList, data) + + return t.cast( + t.List[OrderedDataDict], + super().to_internal_value(data), + ) + + # Else, performing a bulk update. + data = t.cast(BulkUpdateDataDict, data) + data_items = list(data.items()) + + # Models and data are required to be sorted by the lookup field. + data_items.sort(key=lambda item: item[0]) + self.instance.sort( + key=lambda model: getattr(model, self.view.lookup_field) + ) + + return t.cast( + t.List[OrderedDataDict], + super().to_internal_value([item[1] for item in data_items]), + ) + + # pylint: disable-next=useless-parent-delegation,arguments-renamed + def to_representation(self, instance: t.List[AnyModel]) -> t.List[DataDict]: + return super().to_representation(instance) From d19d1abab46afc9d6594343f11a3a27d79491596 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 18:38:33 +0000 Subject: [PATCH 18/56] abstract model view and serializer --- codeforlife/serializers/__init__.py | 2 +- codeforlife/serializers/base.py | 16 ++++----- codeforlife/serializers/model.py | 27 ++++++++------ codeforlife/views/__init__.py | 4 +-- codeforlife/views/api.py | 15 +++++--- codeforlife/views/model.py | 56 ++++++++++++++++++----------- 6 files changed, 71 insertions(+), 49 deletions(-) diff --git a/codeforlife/serializers/__init__.py b/codeforlife/serializers/__init__.py index 3e8c453c..b4299116 100644 --- a/codeforlife/serializers/__init__.py +++ b/codeforlife/serializers/__init__.py @@ -4,5 +4,5 @@ """ from .base import BaseSerializer -from .model import ModelSerializer +from .model import BaseModelSerializer, ModelSerializer from .model_list import ModelListSerializer diff --git a/codeforlife/serializers/base.py b/codeforlife/serializers/base.py index 8f43af79..0ddd962e 100644 --- a/codeforlife/serializers/base.py +++ b/codeforlife/serializers/base.py @@ -10,26 +10,22 @@ from django.views import View from rest_framework.serializers import BaseSerializer as _BaseSerializer -from ..request import Request +from ..request import BaseRequest -# pylint: disable-next=duplicate-code -if t.TYPE_CHECKING: - from ..user.models import User - - RequestUser = t.TypeVar("RequestUser", bound=User) -else: - RequestUser = t.TypeVar("RequestUser") +# pylint: disable=duplicate-code +AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest) +# pylint: enable=duplicate-code # pylint: disable-next=abstract-method -class BaseSerializer(_BaseSerializer, t.Generic[RequestUser]): +class BaseSerializer(_BaseSerializer, t.Generic[AnyBaseRequest]): """Base serializer to be inherited by all other serializers.""" @property def request(self): """The HTTP request that triggered the view.""" - return t.cast(Request[RequestUser], self.context["request"]) + return t.cast(AnyBaseRequest, self.context["request"]) @property def view(self): diff --git a/codeforlife/serializers/model.py b/codeforlife/serializers/model.py index e2fca15a..7b5601c2 100644 --- a/codeforlife/serializers/model.py +++ b/codeforlife/serializers/model.py @@ -10,37 +10,33 @@ from django.db.models import Model from rest_framework.serializers import ModelSerializer as _ModelSerializer +from ..request import BaseRequest, Request from ..types import DataDict from .base import BaseSerializer # pylint: disable=duplicate-code if t.TYPE_CHECKING: from ..user.models import User + from ..views import BaseModelViewSet, ModelViewSet RequestUser = t.TypeVar("RequestUser", bound=User) else: RequestUser = t.TypeVar("RequestUser") AnyModel = t.TypeVar("AnyModel", bound=Model) +AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest) # pylint: enable=duplicate-code -class ModelSerializer( - BaseSerializer[RequestUser], +class BaseModelSerializer( + BaseSerializer[AnyBaseRequest], _ModelSerializer[AnyModel], - t.Generic[RequestUser, AnyModel], + t.Generic[AnyBaseRequest, AnyModel], ): """Base model serializer for all model serializers.""" instance: t.Optional[AnyModel] - - @property - def view(self): - # NOTE: import outside top-level to avoid circular imports. - # pylint: disable-next=import-outside-toplevel - from ..views import ModelViewSet - - return t.cast(ModelViewSet[RequestUser, AnyModel], super().view) + view: "BaseModelViewSet[AnyBaseRequest, AnyModel]" @property def non_none_instance(self): @@ -61,3 +57,12 @@ def validate(self, attrs: DataDict): # pylint: disable-next=useless-parent-delegation def to_representation(self, instance: AnyModel) -> DataDict: return super().to_representation(instance) + + +class ModelSerializer( + BaseModelSerializer[Request[RequestUser], AnyModel], + t.Generic[RequestUser, AnyModel], +): + """Base model serializer for all model serializers.""" + + view: "ModelViewSet[RequestUser, AnyModel]" # type: ignore[assignment] diff --git a/codeforlife/views/__init__.py b/codeforlife/views/__init__.py index 1822c7e9..33ecf266 100644 --- a/codeforlife/views/__init__.py +++ b/codeforlife/views/__init__.py @@ -3,7 +3,7 @@ Created on 24/01/2024 at 13:07:38(+00:00). """ -from .api import APIView +from .api import APIView, BaseAPIView from .common import CsrfCookieView, LogoutView from .decorators import action, cron_job -from .model import ModelViewSet +from .model import BaseModelViewSet, ModelViewSet diff --git a/codeforlife/views/api.py b/codeforlife/views/api.py index fe92dfe1..7ec18c2a 100644 --- a/codeforlife/views/api.py +++ b/codeforlife/views/api.py @@ -7,9 +7,9 @@ from rest_framework.views import APIView as _APIView -from ..request import Request +from ..request import BaseRequest, Request -# pylint: disable-next=duplicate-code +# pylint: disable=duplicate-code if t.TYPE_CHECKING: from ..user.models import User @@ -17,11 +17,18 @@ else: RequestUser = t.TypeVar("RequestUser") +AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest) + +# pylint: enable=duplicate-code + # pylint: disable-next=missing-class-docstring -class APIView(_APIView, t.Generic[RequestUser]): - request: Request[RequestUser] +class BaseAPIView(_APIView, t.Generic[AnyBaseRequest]): + request: AnyBaseRequest + +# pylint: disable-next=missing-class-docstring +class APIView(BaseAPIView[Request[RequestUser]], t.Generic[RequestUser]): @classmethod def get_request_user_class(cls) -> t.Type[RequestUser]: """Get the request's user class. diff --git a/codeforlife/views/model.py b/codeforlife/views/model.py index 3f5f6a39..93ec7f5c 100644 --- a/codeforlife/views/model.py +++ b/codeforlife/views/model.py @@ -14,16 +14,20 @@ from rest_framework.viewsets import ModelViewSet as DrfModelViewSet from ..permissions import Permission -from ..request import Request +from ..request import BaseRequest, Request from ..types import KwArgs -from .api import APIView +from .api import APIView, BaseAPIView from .decorators import action AnyModel = t.TypeVar("AnyModel", bound=Model) -# pylint: disable-next=duplicate-code +# pylint: disable=duplicate-code if t.TYPE_CHECKING: # pragma: no cover - from ..serializers import ModelListSerializer, ModelSerializer + from ..serializers import ( + BaseModelSerializer, + ModelListSerializer, + ModelSerializer, + ) from ..user.models import User RequestUser = t.TypeVar("RequestUser", bound=User) @@ -41,16 +45,21 @@ class _ModelViewSet(DrfModelViewSet, t.Generic[AnyModel]): pass +AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest) + +# pylint: enable=duplicate-code + + # pylint: disable-next=too-many-ancestors -class ModelViewSet( - APIView[RequestUser], +class BaseModelViewSet( + BaseAPIView[AnyBaseRequest], _ModelViewSet[AnyModel], - t.Generic[RequestUser, AnyModel], + t.Generic[AnyBaseRequest, AnyModel], ): """Base model view set for all model view sets.""" serializer_class: t.Optional[ - t.Type["ModelSerializer[RequestUser, AnyModel]"] + t.Type["BaseModelSerializer[AnyBaseRequest, AnyModel]"] ] @classmethod @@ -114,47 +123,52 @@ class _ModelListSerializer( return serializer - # -------------------------------------------------------------------------- - # View Set Actions - # -------------------------------------------------------------------------- - # pylint: disable=useless-parent-delegation def destroy( # type: ignore[override] # pragma: no cover - self, request: Request[RequestUser], *args, **kwargs + self, request: AnyBaseRequest, *args, **kwargs ): return super().destroy(request, *args, **kwargs) def create( # type: ignore[override] # pragma: no cover - self, request: Request[RequestUser], *args, **kwargs + self, request: AnyBaseRequest, *args, **kwargs ): return super().create(request, *args, **kwargs) def list( # type: ignore[override] # pragma: no cover - self, request: Request[RequestUser], *args, **kwargs + self, request: AnyBaseRequest, *args, **kwargs ): return super().list(request, *args, **kwargs) def retrieve( # type: ignore[override] # pragma: no cover - self, request: Request[RequestUser], *args, **kwargs + self, request: AnyBaseRequest, *args, **kwargs ): return super().retrieve(request, *args, **kwargs) def update( # type: ignore[override] # pragma: no cover - self, request: Request[RequestUser], *args, **kwargs + self, request: AnyBaseRequest, *args, **kwargs ): return super().update(request, *args, **kwargs) def partial_update( # type: ignore[override] # pragma: no cover - self, request: Request[RequestUser], *args, **kwargs + self, request: AnyBaseRequest, *args, **kwargs ): return super().partial_update(request, *args, **kwargs) # pylint: enable=useless-parent-delegation - # -------------------------------------------------------------------------- - # Bulk Actions - # -------------------------------------------------------------------------- + +# pylint: disable-next=too-many-ancestors +class ModelViewSet( + BaseModelViewSet[Request[RequestUser], AnyModel], + APIView[RequestUser], + t.Generic[RequestUser, AnyModel], +): + """Base model view set for all model view sets.""" + + serializer_class: t.Optional[ + t.Type["ModelSerializer[RequestUser, AnyModel]"] + ] def get_bulk_queryset(self, lookup_values: t.Collection): """Get the queryset for a bulk action. From 0f695c9f0d100af5069dea80526ebdb815245f4d Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 18:46:18 +0000 Subject: [PATCH 19/56] abstract model list --- codeforlife/serializers/model_list.py | 42 ++++++++++++++++++++------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/codeforlife/serializers/model_list.py b/codeforlife/serializers/model_list.py index d292c2bd..18249adb 100644 --- a/codeforlife/serializers/model_list.py +++ b/codeforlife/serializers/model_list.py @@ -11,18 +11,21 @@ from rest_framework.serializers import ListSerializer as _ListSerializer from rest_framework.serializers import ValidationError as _ValidationError +from ..request import BaseRequest, Request from ..types import DataDict, OrderedDataDict from .base import BaseSerializer # pylint: disable=duplicate-code if t.TYPE_CHECKING: from ..user.models import User + from ..views import BaseModelViewSet, ModelViewSet RequestUser = t.TypeVar("RequestUser", bound=User) else: RequestUser = t.TypeVar("RequestUser") AnyModel = t.TypeVar("AnyModel", bound=Model) +AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest) # pylint: enable=duplicate-code BulkCreateDataList = t.List[DataDict] @@ -30,10 +33,10 @@ Data = t.Union[BulkCreateDataList, BulkUpdateDataDict] -class ModelListSerializer( - BaseSerializer[RequestUser], +class BaseModelListSerializer( + BaseSerializer[AnyBaseRequest], _ListSerializer[t.List[AnyModel]], - t.Generic[RequestUser, AnyModel], + t.Generic[AnyBaseRequest, AnyModel], ): """Base model list serializer for all model list serializers. @@ -54,14 +57,7 @@ class Meta: instance: t.Optional[t.List[AnyModel]] batch_size: t.Optional[int] = None - - @property - def view(self): - # NOTE: import outside top-level to avoid circular imports. - # pylint: disable-next=import-outside-toplevel - from ..views import ModelViewSet - - return t.cast(ModelViewSet[RequestUser, AnyModel], super().view) + view: "BaseModelViewSet[AnyBaseRequest, AnyModel]" @property def non_none_instance(self): @@ -186,3 +182,27 @@ def to_internal_value(self, data: Data): # pylint: disable-next=useless-parent-delegation,arguments-renamed def to_representation(self, instance: t.List[AnyModel]) -> t.List[DataDict]: return super().to_representation(instance) + + +class ModelListSerializer( + BaseModelListSerializer[Request[RequestUser], AnyModel], + t.Generic[RequestUser, AnyModel], +): + """Base model list serializer for all model list serializers. + + Inherit this class if you wish to custom handle bulk create and/or update. + + class UserListSerializer(ModelListSerializer[User, User]): + def create(self, validated_data): + ... + + def update(self, instance, validated_data): + ... + + class UserSerializer(ModelSerializer[User, User]): + class Meta: + model = User + list_serializer_class = UserListSerializer + """ + + view: "ModelViewSet[RequestUser, AnyModel]" # type: ignore[assignment] From f9294cd5f6a8b76874949a0e9bf2c521cd4733bb Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 18:46:46 +0000 Subject: [PATCH 20/56] import BaseModelListSerializer --- codeforlife/serializers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeforlife/serializers/__init__.py b/codeforlife/serializers/__init__.py index b4299116..e9cec59e 100644 --- a/codeforlife/serializers/__init__.py +++ b/codeforlife/serializers/__init__.py @@ -5,4 +5,4 @@ from .base import BaseSerializer from .model import BaseModelSerializer, ModelSerializer -from .model_list import ModelListSerializer +from .model_list import BaseModelListSerializer, ModelListSerializer From 24928702aa4feec127b86a2ab566c4b8c9ab7da7 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Tue, 5 Nov 2024 18:55:14 +0000 Subject: [PATCH 21/56] disable missing-function-docstring --- codeforlife/user/views/klass.py | 1 + codeforlife/user/views/user.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/codeforlife/user/views/klass.py b/codeforlife/user/views/klass.py index 3da9c783..2a4c7b07 100644 --- a/codeforlife/user/views/klass.py +++ b/codeforlife/user/views/klass.py @@ -19,6 +19,7 @@ class ClassViewSet(ModelViewSet[RequestUser, Class]): serializer_class = ClassSerializer filterset_class = ClassFilterSet + # pylint: disable-next=missing-function-docstring def get_permissions(self): # Only school-teachers can list classes. if self.action == "list": diff --git a/codeforlife/user/views/user.py b/codeforlife/user/views/user.py index 90aa9c8c..6b9e08cc 100644 --- a/codeforlife/user/views/user.py +++ b/codeforlife/user/views/user.py @@ -21,6 +21,7 @@ class UserViewSet(ModelViewSet[RequestUser, User]): serializer_class = UserSerializer[User] filterset_class = UserFilterSet + # pylint: disable-next=missing-function-docstring def get_queryset( self, user_class: t.Type[AnyUser] = User, # type: ignore[assignment] @@ -67,6 +68,7 @@ def get_queryset( return queryset.filter(pk=user.pk) + # pylint: disable-next=missing-function-docstring def get_bulk_queryset( # pragma: no cover self, lookup_values: t.Collection, From 77925a7c737988923077c74c38c4d649198b7a90 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 6 Nov 2024 10:42:50 +0000 Subject: [PATCH 22/56] init request --- codeforlife/views/api.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/codeforlife/views/api.py b/codeforlife/views/api.py index 7ec18c2a..facbde6d 100644 --- a/codeforlife/views/api.py +++ b/codeforlife/views/api.py @@ -25,10 +25,25 @@ # pylint: disable-next=missing-class-docstring class BaseAPIView(_APIView, t.Generic[AnyBaseRequest]): request: AnyBaseRequest + request_class: t.Type[AnyBaseRequest] + + def initialize_request(self, request, *args, **kwargs): + # NOTE: Call to super has side effects and is required. + super().initialize_request(request, *args, **kwargs) + + return self.request_class( + request=request, + parsers=self.get_parsers(), + authenticators=self.get_authenticators(), + negotiator=self.get_content_negotiator(), + parser_context=self.get_parser_context(request), + ) # pylint: disable-next=missing-class-docstring class APIView(BaseAPIView[Request[RequestUser]], t.Generic[RequestUser]): + request_class = Request + @classmethod def get_request_user_class(cls) -> t.Type[RequestUser]: """Get the request's user class. @@ -45,7 +60,7 @@ def initialize_request(self, request, *args, **kwargs): # NOTE: Call to super has side effects and is required. super().initialize_request(request, *args, **kwargs) - return Request( + return self.request_class( user_class=self.get_request_user_class(), request=request, parsers=self.get_parsers(), From b805632573a12224713fc504e290b9475ef2073b Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 6 Nov 2024 12:14:07 +0000 Subject: [PATCH 23/56] fix: init request --- codeforlife/views/api.py | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/codeforlife/views/api.py b/codeforlife/views/api.py index facbde6d..ff143590 100644 --- a/codeforlife/views/api.py +++ b/codeforlife/views/api.py @@ -5,6 +5,7 @@ import typing as t +from django.http import HttpRequest from rest_framework.views import APIView as _APIView from ..request import BaseRequest, Request @@ -27,17 +28,20 @@ class BaseAPIView(_APIView, t.Generic[AnyBaseRequest]): request: AnyBaseRequest request_class: t.Type[AnyBaseRequest] + def _initialize_request(self, request: HttpRequest, **kwargs): + kwargs["request"] = request + kwargs.setdefault("parsers", self.get_parsers()) + kwargs.setdefault("authenticators", self.get_authenticators()) + kwargs.setdefault("negotiator", self.get_content_negotiator()) + kwargs.setdefault("parser_context", self.get_parser_context(request)) + + return self.request_class(**kwargs) + def initialize_request(self, request, *args, **kwargs): # NOTE: Call to super has side effects and is required. super().initialize_request(request, *args, **kwargs) - return self.request_class( - request=request, - parsers=self.get_parsers(), - authenticators=self.get_authenticators(), - negotiator=self.get_content_negotiator(), - parser_context=self.get_parser_context(request), - ) + return self._initialize_request(request) # pylint: disable-next=missing-class-docstring @@ -56,15 +60,6 @@ def get_request_user_class(cls) -> t.Type[RequestUser]: 0 ] - def initialize_request(self, request, *args, **kwargs): - # NOTE: Call to super has side effects and is required. - super().initialize_request(request, *args, **kwargs) - - return self.request_class( - user_class=self.get_request_user_class(), - request=request, - parsers=self.get_parsers(), - authenticators=self.get_authenticators(), - negotiator=self.get_content_negotiator(), - parser_context=self.get_parser_context(request), - ) + def _initialize_request(self, request, **kwargs): + kwargs["user_class"] = self.get_request_user_class() + return super()._initialize_request(request, **kwargs) From 27df3eeebbc3e486d49276906c2efa4c7702fc3b Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 6 Nov 2024 13:12:52 +0000 Subject: [PATCH 24/56] abstract model serializer test case --- codeforlife/tests/__init__.py | 6 +- codeforlife/tests/model_list_serializer.py | 88 ++++++++++++++++ codeforlife/tests/model_serializer.py | 117 +++++++++++---------- 3 files changed, 157 insertions(+), 54 deletions(-) create mode 100644 codeforlife/tests/model_list_serializer.py diff --git a/codeforlife/tests/__init__.py b/codeforlife/tests/__init__.py index 9943912c..a2135dd9 100644 --- a/codeforlife/tests/__init__.py +++ b/codeforlife/tests/__init__.py @@ -9,8 +9,12 @@ from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory from .cron import CronTestCase from .model import ModelTestCase -from .model_serializer import ( +from .model_list_serializer import ( + BaseModelListSerializerTestCase, ModelListSerializerTestCase, +) +from .model_serializer import ( + BaseModelSerializerTestCase, ModelSerializerTestCase, ) from .model_view_set import ModelViewSetClient, ModelViewSetTestCase diff --git a/codeforlife/tests/model_list_serializer.py b/codeforlife/tests/model_list_serializer.py new file mode 100644 index 00000000..f25da034 --- /dev/null +++ b/codeforlife/tests/model_list_serializer.py @@ -0,0 +1,88 @@ +""" +© Ocado Group +Created on 06/11/2024 at 12:45:33(+00:00). + +Base test case for all model list serializers. +""" + +import typing as t + +from django.db.models import Model + +from ..serializers import ( + BaseModelListSerializer, + BaseModelSerializer, + ModelListSerializer, + ModelSerializer, +) +from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory +from .model_serializer import ( + BaseModelSerializerTestCase, + ModelSerializerTestCase, +) + +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User + + RequestUser = t.TypeVar("RequestUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") + +AnyModel = t.TypeVar("AnyModel", bound=Model) +AnyBaseModelSerializer = t.TypeVar( + "AnyBaseModelSerializer", bound=BaseModelSerializer +) +AnyBaseModelListSerializer = t.TypeVar( + "AnyBaseModelListSerializer", bound=BaseModelListSerializer +) +AnyBaseAPIRequestFactory = t.TypeVar( + "AnyBaseAPIRequestFactory", bound=BaseAPIRequestFactory +) +# pylint: enable=duplicate-code + + +class BaseModelListSerializerTestCase( + BaseModelSerializerTestCase[ + AnyBaseModelSerializer, + AnyBaseAPIRequestFactory, + AnyModel, + ], + t.Generic[ + AnyBaseModelListSerializer, + AnyBaseModelSerializer, + AnyBaseAPIRequestFactory, + AnyModel, + ], +): + """Base for all model serializer test cases.""" + + model_list_serializer_class: t.Type[AnyBaseModelListSerializer] + + REQUIRED_ATTRS = { + "model_list_serializer_class", + "model_serializer_class", + "request_factory_class", + } + + def _init_model_serializer(self, *args, parent=None, **kwargs): + kwargs.setdefault("child", self.model_serializer_class()) + serializer = self.model_list_serializer_class(*args, **kwargs) + if parent: + serializer.parent = parent + + return serializer + + +# pylint: disable-next=too-many-ancestors +class ModelListSerializerTestCase( + BaseModelListSerializerTestCase[ + ModelListSerializer[RequestUser, AnyModel], + ModelSerializer[RequestUser, AnyModel], + APIRequestFactory[RequestUser], + AnyModel, + ], + ModelSerializerTestCase[RequestUser, AnyModel], + t.Generic[RequestUser, AnyModel], +): + """Base for all model serializer test cases.""" diff --git a/codeforlife/tests/model_serializer.py b/codeforlife/tests/model_serializer.py index dcda7bba..8f01541a 100644 --- a/codeforlife/tests/model_serializer.py +++ b/codeforlife/tests/model_serializer.py @@ -13,12 +13,16 @@ from django.forms.models import model_to_dict from rest_framework.serializers import BaseSerializer, ValidationError -from ..serializers import ModelListSerializer, ModelSerializer +from ..serializers import ( + BaseModelSerializer, + ModelListSerializer, + ModelSerializer, +) from ..types import DataDict -from .api_request_factory import APIRequestFactory +from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory from .test import TestCase -# pylint: disable-next=duplicate-code +# pylint: disable=duplicate-code if t.TYPE_CHECKING: from ..user.models import User @@ -26,50 +30,45 @@ else: RequestUser = t.TypeVar("RequestUser") - AnyModel = t.TypeVar("AnyModel", bound=Model) +AnyBaseModelSerializer = t.TypeVar( + "AnyBaseModelSerializer", bound=BaseModelSerializer +) +AnyBaseAPIRequestFactory = t.TypeVar( + "AnyBaseAPIRequestFactory", bound=BaseAPIRequestFactory +) +# pylint: enable=duplicate-code + + +class BaseModelSerializerTestCase( + TestCase, + t.Generic[AnyBaseModelSerializer, AnyBaseAPIRequestFactory, AnyModel], +): + """Base for all model serializer test cases.""" + model_serializer_class: t.Type[AnyBaseModelSerializer] -class ModelSerializerTestCase(TestCase, t.Generic[RequestUser, AnyModel]): - """Base for all model serializer test cases.""" + request_factory: AnyBaseAPIRequestFactory + request_factory_class: t.Type[AnyBaseAPIRequestFactory] - model_serializer_class: t.Type[ModelSerializer[RequestUser, AnyModel]] + REQUIRED_ATTRS: t.Set[str] = { + "model_serializer_class", + "request_factory_class", + } - request_factory: APIRequestFactory[RequestUser] + @classmethod + def _initialize_request_factory(cls, **kwargs): + return cls.request_factory_class(**kwargs) @classmethod def setUpClass(cls): - attr_name = "model_serializer_class" - assert hasattr(cls, attr_name), f'Attribute "{attr_name}" must be set.' + for attr in cls.REQUIRED_ATTRS: + assert hasattr(cls, attr), f'Attribute "{attr}" must be set.' - cls.request_factory = APIRequestFactory(cls.get_request_user_class()) + cls.request_factory = cls._initialize_request_factory() return super().setUpClass() - @classmethod - def get_request_user_class(cls) -> t.Type[AnyModel]: - """Get the model view set's class. - - Returns: - The model view set's class. - """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] - - @classmethod - def get_model_class(cls) -> t.Type[AnyModel]: - """Get the model view set's class. - - Returns: - The model view set's class. - """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 1 - ] - # -------------------------------------------------------------------------- # Private helpers. # -------------------------------------------------------------------------- @@ -379,31 +378,43 @@ def assert_new_data_is_subset_of_data(new_data: DataDict, data): self._assert_data_is_subset_of_model(data, instance) -class ModelListSerializerTestCase( - ModelSerializerTestCase[RequestUser, AnyModel], +class ModelSerializerTestCase( + BaseModelSerializerTestCase[ + ModelSerializer[RequestUser, AnyModel], + APIRequestFactory[RequestUser], + AnyModel, + ], t.Generic[RequestUser, AnyModel], ): """Base for all model serializer test cases.""" - model_list_serializer_class: t.Type[ - ModelListSerializer[RequestUser, AnyModel] - ] + request_factory_class = APIRequestFactory @classmethod - def setUpClass(cls): - attr_name = "model_list_serializer_class" - assert hasattr(cls, attr_name), f'Attribute "{attr_name}" must be set.' + def get_request_user_class(cls) -> t.Type[AnyModel]: + """Get the model view set's class. - return super().setUpClass() + Returns: + The model view set's class. + """ + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 0 + ] - # -------------------------------------------------------------------------- - # Private helpers. - # -------------------------------------------------------------------------- + @classmethod + def get_model_class(cls) -> t.Type[AnyModel]: + """Get the model view set's class. - def _init_model_serializer(self, *args, parent=None, **kwargs): - kwargs.setdefault("child", self.model_serializer_class()) - serializer = self.model_list_serializer_class(*args, **kwargs) - if parent: - serializer.parent = parent + Returns: + The model view set's class. + """ + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 1 + ] - return serializer + @classmethod + def _initialize_request_factory(cls, **kwargs): + kwargs["user_class"] = cls.get_request_user_class() + return super()._initialize_request_factory(**kwargs) From 48c9b4661b8ce74467b13447227cbdbd29612b9c Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 6 Nov 2024 13:27:43 +0000 Subject: [PATCH 25/56] fix types --- codeforlife/tests/model_serializer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/codeforlife/tests/model_serializer.py b/codeforlife/tests/model_serializer.py index 8f01541a..984652a5 100644 --- a/codeforlife/tests/model_serializer.py +++ b/codeforlife/tests/model_serializer.py @@ -14,8 +14,8 @@ from rest_framework.serializers import BaseSerializer, ValidationError from ..serializers import ( + BaseModelListSerializer, BaseModelSerializer, - ModelListSerializer, ModelSerializer, ) from ..types import DataDict @@ -141,7 +141,7 @@ def _assert_many( new_data: t.Optional[t.List[DataDict]], non_model_fields: t.Optional[NonModelFields], get_models: t.Callable[ - [ModelListSerializer[RequestUser, AnyModel], t.List[DataDict]], + [BaseModelListSerializer[t.Any, AnyModel], t.List[DataDict]], t.List[AnyModel], ], *args, @@ -154,7 +154,7 @@ def _assert_many( assert len(new_data) == len(validated_data) kwargs.pop("many", None) # many must be True - serializer: ModelListSerializer[RequestUser, AnyModel] = ( + serializer: BaseModelListSerializer[t.Any, AnyModel] = ( self._init_model_serializer(*args, **kwargs, many=True) ) From 19cbfdacabd66e8c3f8ebad3f50d228aef58cdb8 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 6 Nov 2024 13:31:06 +0000 Subject: [PATCH 26/56] fix linting issues --- codeforlife/user/serializers/user_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeforlife/user/serializers/user_test.py b/codeforlife/user/serializers/user_test.py index 415a1462..b07240a1 100644 --- a/codeforlife/user/serializers/user_test.py +++ b/codeforlife/user/serializers/user_test.py @@ -8,7 +8,7 @@ from .user import UserSerializer -# pylint: disable-next=missing-class-docstring +# pylint: disable-next=missing-class-docstring,too-many-ancestors class TestUserSerializer(ModelSerializerTestCase[User, User]): model_serializer_class = UserSerializer From 5a85de9d71ca1417ed616ab34c88bd85d89ae3bf Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 6 Nov 2024 13:40:20 +0000 Subject: [PATCH 27/56] split code --- codeforlife/tests/__init__.py | 3 +- codeforlife/tests/api.py | 507 +------------------------------ codeforlife/tests/api_client.py | 518 ++++++++++++++++++++++++++++++++ 3 files changed, 524 insertions(+), 504 deletions(-) create mode 100644 codeforlife/tests/api_client.py diff --git a/codeforlife/tests/__init__.py b/codeforlife/tests/__init__.py index a2135dd9..38f4ee00 100644 --- a/codeforlife/tests/__init__.py +++ b/codeforlife/tests/__init__.py @@ -5,7 +5,8 @@ Custom test cases. """ -from .api import APIClient, APITestCase +from .api import APITestCase +from .api_client import APIClient from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory from .cron import CronTestCase from .model import ModelTestCase diff --git a/codeforlife/tests/api.py b/codeforlife/tests/api.py index e70fb1c3..58aec421 100644 --- a/codeforlife/tests/api.py +++ b/codeforlife/tests/api.py @@ -3,518 +3,19 @@ Created on 23/02/2024 at 08:46:27(+00:00). """ -import json import typing as t -from unittest.mock import patch -from django.utils import timezone -from rest_framework import status -from rest_framework.response import Response -from rest_framework.test import APIClient as _APIClient - -from ..types import DataDict, JsonDict -from .api_request_factory import APIRequestFactory +from .api_client import APIClient from .test import TestCase -# pylint: disable-next=duplicate-code +# pylint: disable=duplicate-code if t.TYPE_CHECKING: - from ..user.models import TypedUser, User + from ..user.models import User RequestUser = t.TypeVar("RequestUser", bound=User) - LoginUser = t.TypeVar("LoginUser", bound=User) else: RequestUser = t.TypeVar("RequestUser") - LoginUser = t.TypeVar("LoginUser") - - -class APIClient(_APIClient, t.Generic[RequestUser]): - """Base API client to be inherited by all other API clients.""" - - _test_case: "APITestCase[RequestUser]" - - def __init__( - self, - enforce_csrf_checks: bool = False, - raise_request_exception=False, - **defaults, - ): - super().__init__( - enforce_csrf_checks, - raise_request_exception=raise_request_exception, - **defaults, - ) - - self.request_factory = APIRequestFactory( - self.get_request_user_class(), - enforce_csrf_checks, - **defaults, - ) - - @classmethod - def get_request_user_class(cls) -> t.Type[RequestUser]: - """Get the request's user class. - - Returns: - The request's user class. - """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] - - @staticmethod - def status_code_is_ok(status_code: int): - """Check if the status code is greater than or equal to 200 and less - than 300. - - Args: - status_code: The status code to check. - - Returns: - A flag designating if the status code is OK. - """ - return 200 <= status_code < 300 - - # -------------------------------------------------------------------------- - # Assert Response Helpers - # -------------------------------------------------------------------------- - - def _assert_response(self, response: Response, make_assertions: t.Callable): - if self.status_code_is_ok(response.status_code): - make_assertions() - - def _assert_response_json( - self, - response: Response, - make_assertions: t.Callable[[JsonDict], None], - ): - self._assert_response( - response, - make_assertions=lambda: make_assertions( - response.json(), # type: ignore[attr-defined] - ), - ) - - def _assert_response_json_bulk( - self, - response: Response, - make_assertions: t.Callable[[t.List[JsonDict]], None], - data: t.List[DataDict], - ): - def _make_assertions(): - response_json = response.json() # type: ignore[attr-defined] - assert isinstance(response_json, list) - assert len(response_json) == len(data) - make_assertions(response_json) - - self._assert_response(response, _make_assertions) - - # -------------------------------------------------------------------------- - # Login Helpers - # -------------------------------------------------------------------------- - - def _login_user_type(self, user_type: t.Type[LoginUser], **credentials): - # pylint: disable-next=import-outside-toplevel - from ..user.models import AuthFactor - - # Logout current user (if any) before logging in next user. - self.logout() - assert super().login( - **credentials - ), f"Failed to login with credentials: {credentials}." - - user = user_type.objects.get(session=self.session.session_key) - - if user.session.auth_factors.filter( - auth_factor__type=AuthFactor.Type.OTP - ).exists(): - now = timezone.now() - otp = user.totp.at(now) - with patch.object(timezone, "now", return_value=now): - assert super().login( - request=self.request_factory.post( - user=t.cast(RequestUser, user) - ), - otp=otp, - ), f'Failed to login with OTP "{otp}" at {now}.' - - assert user.is_authenticated, "Failed to authenticate user." - - return user - - def login(self, **credentials): - """Log in a user. - - Returns: - The user. - """ - return self._login_user_type(User, **credentials) - - def login_teacher(self, email: str, password: str = "password"): - """Log in a user and assert they are a teacher. - - Args: - email: The user's email address. - password: The user's password. - - Returns: - The teacher-user. - """ - # pylint: disable-next=import-outside-toplevel - from ..user.models import TeacherUser - - return self._login_user_type( - TeacherUser, email=email, password=password - ) - - def login_school_teacher(self, email: str, password: str = "password"): - """Log in a user and assert they are a school-teacher. - - Args: - email: The user's email address. - password: The user's password. - - Returns: - The school-teacher-user. - """ - # pylint: disable-next=import-outside-toplevel - from ..user.models import SchoolTeacherUser - - return self._login_user_type( - SchoolTeacherUser, email=email, password=password - ) - - def login_admin_school_teacher( - self, email: str, password: str = "password" - ): - """Log in a user and assert they are an admin-school-teacher. - - Args: - email: The user's email address. - password: The user's password. - - Returns: - The admin-school-teacher-user. - """ - # pylint: disable-next=import-outside-toplevel - from ..user.models import AdminSchoolTeacherUser - - return self._login_user_type( - AdminSchoolTeacherUser, email=email, password=password - ) - - def login_non_admin_school_teacher( - self, email: str, password: str = "password" - ): - """Log in a user and assert they are a non-admin-school-teacher. - - Args: - email: The user's email address. - password: The user's password. - - Returns: - The non-admin-school-teacher-user. - """ - # pylint: disable-next=import-outside-toplevel - from ..user.models import NonAdminSchoolTeacherUser - - return self._login_user_type( - NonAdminSchoolTeacherUser, email=email, password=password - ) - - def login_non_school_teacher(self, email: str, password: str = "password"): - """Log in a user and assert they are a non-school-teacher. - - Args: - email: The user's email address. - password: The user's password. - - Returns: - The non-school-teacher-user. - """ - # pylint: disable-next=import-outside-toplevel - from ..user.models import NonSchoolTeacherUser - - return self._login_user_type( - NonSchoolTeacherUser, email=email, password=password - ) - - def login_student( - self, class_id: str, first_name: str, password: str = "password" - ): - """Log in a user and assert they are a student. - - Args: - class_id: The ID of the class the student belongs to. - first_name: The user's first name. - password: The user's password. - - Returns: - The student-user. - """ - # pylint: disable-next=import-outside-toplevel - from ..user.models import StudentUser - - return self._login_user_type( - StudentUser, - first_name=first_name, - password=password, - class_id=class_id, - ) - - def login_indy(self, email: str, password: str = "password"): - """Log in a user and assert they are an independent. - - Args: - email: The user's email address. - password: The user's password. - - Returns: - The independent-user. - """ - # pylint: disable-next=import-outside-toplevel - from ..user.models import IndependentUser - - return self._login_user_type( - IndependentUser, email=email, password=password - ) - - def login_as(self, user: "TypedUser", password: str = "password"): - """Log in as a user. The user instance needs to be a user proxy in order - to know which credentials are required. - - Args: - user: The user to log in as. - password: The user's password. - """ - # pylint: disable-next=import-outside-toplevel - from ..user.models import ( - AdminSchoolTeacherUser, - IndependentUser, - NonAdminSchoolTeacherUser, - NonSchoolTeacherUser, - SchoolTeacherUser, - StudentUser, - TeacherUser, - ) - - auth_user = None - - if isinstance(user, AdminSchoolTeacherUser): - auth_user = self.login_admin_school_teacher(user.email, password) - elif isinstance(user, NonAdminSchoolTeacherUser): - auth_user = self.login_non_admin_school_teacher( - user.email, password - ) - elif isinstance(user, SchoolTeacherUser): - auth_user = self.login_school_teacher(user.email, password) - elif isinstance(user, NonSchoolTeacherUser): - auth_user = self.login_non_school_teacher(user.email, password) - elif isinstance(user, TeacherUser): - auth_user = self.login_teacher(user.email, password) - elif isinstance(user, StudentUser): - auth_user = self.login_student( - user.student.class_field.access_code, - user.first_name, - password, - ) - elif isinstance(user, IndependentUser): - auth_user = self.login_indy(user.email, password) - - assert user == auth_user - - # -------------------------------------------------------------------------- - # Request Helpers - # -------------------------------------------------------------------------- - - StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]] - - # pylint: disable=too-many-arguments,redefined-builtin - - def generic( - self, - method, - path, - data="", - content_type="application/json", - secure=False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - response = t.cast( - Response, - super().generic( - method, - path, - data, - content_type, - secure, - **extra, - ), - ) - - # Use a custom kwarg to handle the common case of checking the - # response's status code. - if status_code_assertion is None: - status_code_assertion = self.status_code_is_ok - elif isinstance(status_code_assertion, int): - expected_status_code = status_code_assertion - status_code_assertion = ( - # pylint: disable-next=unnecessary-lambda-assignment - lambda status_code: status_code - == expected_status_code - ) - - # pylint: disable-next=no-member - status_code = response.status_code - assert status_code_assertion( - status_code - ), f"Unexpected status code: {status_code}." + ( - "\nValidation errors: " - + json.dumps( - # pylint: disable-next=no-member - response.json(), # type: ignore[attr-defined] - indent=2, - default=str, - ) - if status_code == status.HTTP_400_BAD_REQUEST - else "" - ) - - return response - - def get( # type: ignore[override] - self, - path: str, - data: t.Any = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - return super().get( - path=path, - data=data, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - def post( # type: ignore[override] - self, - path: str, - data: t.Any = None, - format: t.Optional[str] = None, - content_type: t.Optional[str] = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - if format is None and content_type is None: - format = "json" - - return super().post( - path=path, - data=data, - format=format, - content_type=content_type, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - def put( # type: ignore[override] - self, - path: str, - data: t.Any = None, - format: t.Optional[str] = None, - content_type: t.Optional[str] = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - if format is None and content_type is None: - format = "json" - - return super().put( - path=path, - data=data, - format=format, - content_type=content_type, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - def patch( # type: ignore[override] - self, - path: str, - data: t.Any = None, - format: t.Optional[str] = None, - content_type: t.Optional[str] = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - if format is None and content_type is None: - format = "json" - - return super().patch( - path=path, - data=data, - format=format, - content_type=content_type, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - def delete( # type: ignore[override] - self, - path: str, - data: t.Any = None, - format: t.Optional[str] = None, - content_type: t.Optional[str] = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - if format is None and content_type is None: - format = "json" - - return super().delete( - path=path, - data=data, - format=format, - content_type=content_type, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - def options( # type: ignore[override] - self, - path: str, - data: t.Any = None, - format: t.Optional[str] = None, - content_type: t.Optional[str] = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - if format is None and content_type is None: - format = "json" - - return super().options( - path=path, - data=data, - format=format, - content_type=content_type, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - # pylint: enable=too-many-arguments,redefined-builtin +# pylint: enable=duplicate-code class APITestCase(TestCase, t.Generic[RequestUser]): diff --git a/codeforlife/tests/api_client.py b/codeforlife/tests/api_client.py new file mode 100644 index 00000000..31aa4c4e --- /dev/null +++ b/codeforlife/tests/api_client.py @@ -0,0 +1,518 @@ +""" +© Ocado Group +Created on 06/11/2024 at 13:35:13(+00:00). +""" + +import json +import typing as t +from unittest.mock import patch + +from django.utils import timezone +from rest_framework import status +from rest_framework.response import Response +from rest_framework.test import APIClient as _APIClient + +from ..types import DataDict, JsonDict +from .api_request_factory import APIRequestFactory + +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import TypedUser, User + from .api import APITestCase + + RequestUser = t.TypeVar("RequestUser", bound=User) + LoginUser = t.TypeVar("LoginUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") + LoginUser = t.TypeVar("LoginUser") +# pylint: enable=duplicate-code + + +class APIClient(_APIClient, t.Generic[RequestUser]): + """Base API client to be inherited by all other API clients.""" + + _test_case: "APITestCase[RequestUser]" + + def __init__( + self, + enforce_csrf_checks: bool = False, + raise_request_exception=False, + **defaults, + ): + super().__init__( + enforce_csrf_checks, + raise_request_exception=raise_request_exception, + **defaults, + ) + + self.request_factory = APIRequestFactory( + self.get_request_user_class(), + enforce_csrf_checks, + **defaults, + ) + + @classmethod + def get_request_user_class(cls) -> t.Type[RequestUser]: + """Get the request's user class. + + Returns: + The request's user class. + """ + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 0 + ] + + @staticmethod + def status_code_is_ok(status_code: int): + """Check if the status code is greater than or equal to 200 and less + than 300. + + Args: + status_code: The status code to check. + + Returns: + A flag designating if the status code is OK. + """ + return 200 <= status_code < 300 + + # -------------------------------------------------------------------------- + # Assert Response Helpers + # -------------------------------------------------------------------------- + + def _assert_response(self, response: Response, make_assertions: t.Callable): + if self.status_code_is_ok(response.status_code): + make_assertions() + + def _assert_response_json( + self, + response: Response, + make_assertions: t.Callable[[JsonDict], None], + ): + self._assert_response( + response, + make_assertions=lambda: make_assertions( + response.json(), # type: ignore[attr-defined] + ), + ) + + def _assert_response_json_bulk( + self, + response: Response, + make_assertions: t.Callable[[t.List[JsonDict]], None], + data: t.List[DataDict], + ): + def _make_assertions(): + response_json = response.json() # type: ignore[attr-defined] + assert isinstance(response_json, list) + assert len(response_json) == len(data) + make_assertions(response_json) + + self._assert_response(response, _make_assertions) + + # -------------------------------------------------------------------------- + # Login Helpers + # -------------------------------------------------------------------------- + + def _login_user_type(self, user_type: t.Type[LoginUser], **credentials): + # pylint: disable-next=import-outside-toplevel + from ..user.models import AuthFactor + + # Logout current user (if any) before logging in next user. + self.logout() + assert super().login( + **credentials + ), f"Failed to login with credentials: {credentials}." + + user = user_type.objects.get(session=self.session.session_key) + + if user.session.auth_factors.filter( + auth_factor__type=AuthFactor.Type.OTP + ).exists(): + now = timezone.now() + otp = user.totp.at(now) + with patch.object(timezone, "now", return_value=now): + assert super().login( + request=self.request_factory.post( + user=t.cast(RequestUser, user) + ), + otp=otp, + ), f'Failed to login with OTP "{otp}" at {now}.' + + assert user.is_authenticated, "Failed to authenticate user." + + return user + + def login(self, **credentials): + """Log in a user. + + Returns: + The user. + """ + return self._login_user_type(User, **credentials) + + def login_teacher(self, email: str, password: str = "password"): + """Log in a user and assert they are a teacher. + + Args: + email: The user's email address. + password: The user's password. + + Returns: + The teacher-user. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import TeacherUser + + return self._login_user_type( + TeacherUser, email=email, password=password + ) + + def login_school_teacher(self, email: str, password: str = "password"): + """Log in a user and assert they are a school-teacher. + + Args: + email: The user's email address. + password: The user's password. + + Returns: + The school-teacher-user. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import SchoolTeacherUser + + return self._login_user_type( + SchoolTeacherUser, email=email, password=password + ) + + def login_admin_school_teacher( + self, email: str, password: str = "password" + ): + """Log in a user and assert they are an admin-school-teacher. + + Args: + email: The user's email address. + password: The user's password. + + Returns: + The admin-school-teacher-user. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import AdminSchoolTeacherUser + + return self._login_user_type( + AdminSchoolTeacherUser, email=email, password=password + ) + + def login_non_admin_school_teacher( + self, email: str, password: str = "password" + ): + """Log in a user and assert they are a non-admin-school-teacher. + + Args: + email: The user's email address. + password: The user's password. + + Returns: + The non-admin-school-teacher-user. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import NonAdminSchoolTeacherUser + + return self._login_user_type( + NonAdminSchoolTeacherUser, email=email, password=password + ) + + def login_non_school_teacher(self, email: str, password: str = "password"): + """Log in a user and assert they are a non-school-teacher. + + Args: + email: The user's email address. + password: The user's password. + + Returns: + The non-school-teacher-user. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import NonSchoolTeacherUser + + return self._login_user_type( + NonSchoolTeacherUser, email=email, password=password + ) + + def login_student( + self, class_id: str, first_name: str, password: str = "password" + ): + """Log in a user and assert they are a student. + + Args: + class_id: The ID of the class the student belongs to. + first_name: The user's first name. + password: The user's password. + + Returns: + The student-user. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import StudentUser + + return self._login_user_type( + StudentUser, + first_name=first_name, + password=password, + class_id=class_id, + ) + + def login_indy(self, email: str, password: str = "password"): + """Log in a user and assert they are an independent. + + Args: + email: The user's email address. + password: The user's password. + + Returns: + The independent-user. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import IndependentUser + + return self._login_user_type( + IndependentUser, email=email, password=password + ) + + def login_as(self, user: "TypedUser", password: str = "password"): + """Log in as a user. The user instance needs to be a user proxy in order + to know which credentials are required. + + Args: + user: The user to log in as. + password: The user's password. + """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import ( + AdminSchoolTeacherUser, + IndependentUser, + NonAdminSchoolTeacherUser, + NonSchoolTeacherUser, + SchoolTeacherUser, + StudentUser, + TeacherUser, + ) + + auth_user = None + + if isinstance(user, AdminSchoolTeacherUser): + auth_user = self.login_admin_school_teacher(user.email, password) + elif isinstance(user, NonAdminSchoolTeacherUser): + auth_user = self.login_non_admin_school_teacher( + user.email, password + ) + elif isinstance(user, SchoolTeacherUser): + auth_user = self.login_school_teacher(user.email, password) + elif isinstance(user, NonSchoolTeacherUser): + auth_user = self.login_non_school_teacher(user.email, password) + elif isinstance(user, TeacherUser): + auth_user = self.login_teacher(user.email, password) + elif isinstance(user, StudentUser): + auth_user = self.login_student( + user.student.class_field.access_code, + user.first_name, + password, + ) + elif isinstance(user, IndependentUser): + auth_user = self.login_indy(user.email, password) + + assert user == auth_user + + # -------------------------------------------------------------------------- + # Request Helpers + # -------------------------------------------------------------------------- + + StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]] + + # pylint: disable=too-many-arguments,redefined-builtin + + def generic( + self, + method, + path, + data="", + content_type="application/json", + secure=False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + response = t.cast( + Response, + super().generic( + method, + path, + data, + content_type, + secure, + **extra, + ), + ) + + # Use a custom kwarg to handle the common case of checking the + # response's status code. + if status_code_assertion is None: + status_code_assertion = self.status_code_is_ok + elif isinstance(status_code_assertion, int): + expected_status_code = status_code_assertion + status_code_assertion = ( + # pylint: disable-next=unnecessary-lambda-assignment + lambda status_code: status_code + == expected_status_code + ) + + # pylint: disable-next=no-member + status_code = response.status_code + assert status_code_assertion( + status_code + ), f"Unexpected status code: {status_code}." + ( + "\nValidation errors: " + + json.dumps( + # pylint: disable-next=no-member + response.json(), # type: ignore[attr-defined] + indent=2, + default=str, + ) + if status_code == status.HTTP_400_BAD_REQUEST + else "" + ) + + return response + + def get( # type: ignore[override] + self, + path: str, + data: t.Any = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + return super().get( + path=path, + data=data, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + def post( # type: ignore[override] + self, + path: str, + data: t.Any = None, + format: t.Optional[str] = None, + content_type: t.Optional[str] = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + if format is None and content_type is None: + format = "json" + + return super().post( + path=path, + data=data, + format=format, + content_type=content_type, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + def put( # type: ignore[override] + self, + path: str, + data: t.Any = None, + format: t.Optional[str] = None, + content_type: t.Optional[str] = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + if format is None and content_type is None: + format = "json" + + return super().put( + path=path, + data=data, + format=format, + content_type=content_type, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + def patch( # type: ignore[override] + self, + path: str, + data: t.Any = None, + format: t.Optional[str] = None, + content_type: t.Optional[str] = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + if format is None and content_type is None: + format = "json" + + return super().patch( + path=path, + data=data, + format=format, + content_type=content_type, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + def delete( # type: ignore[override] + self, + path: str, + data: t.Any = None, + format: t.Optional[str] = None, + content_type: t.Optional[str] = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + if format is None and content_type is None: + format = "json" + + return super().delete( + path=path, + data=data, + format=format, + content_type=content_type, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + def options( # type: ignore[override] + self, + path: str, + data: t.Any = None, + format: t.Optional[str] = None, + content_type: t.Optional[str] = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + if format is None and content_type is None: + format = "json" + + return super().options( + path=path, + data=data, + format=format, + content_type=content_type, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + # pylint: enable=too-many-arguments,redefined-builtin From 5b1d3bdff49023c04b26421ee52e269f25d22b1f Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 6 Nov 2024 14:12:26 +0000 Subject: [PATCH 28/56] fix: abstract api test case and client --- codeforlife/tests/__init__.py | 4 +- codeforlife/tests/api.py | 19 +- codeforlife/tests/api_client.py | 457 +++++++++++++++++--------------- 3 files changed, 261 insertions(+), 219 deletions(-) diff --git a/codeforlife/tests/__init__.py b/codeforlife/tests/__init__.py index 38f4ee00..58dbce7c 100644 --- a/codeforlife/tests/__init__.py +++ b/codeforlife/tests/__init__.py @@ -5,8 +5,8 @@ Custom test cases. """ -from .api import APITestCase -from .api_client import APIClient +from .api import APITestCase, BaseAPITestCase +from .api_client import APIClient, BaseAPIClient from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory from .cron import CronTestCase from .model import ModelTestCase diff --git a/codeforlife/tests/api.py b/codeforlife/tests/api.py index 58aec421..2b81f0d0 100644 --- a/codeforlife/tests/api.py +++ b/codeforlife/tests/api.py @@ -5,7 +5,7 @@ import typing as t -from .api_client import APIClient +from .api_client import APIClient, BaseAPIClient from .test import TestCase # pylint: disable=duplicate-code @@ -15,14 +15,25 @@ RequestUser = t.TypeVar("RequestUser", bound=User) else: RequestUser = t.TypeVar("RequestUser") + +AnyBaseAPIClient = t.TypeVar("AnyBaseAPIClient", bound=BaseAPIClient) # pylint: enable=duplicate-code -class APITestCase(TestCase, t.Generic[RequestUser]): +class BaseAPITestCase(TestCase, t.Generic[AnyBaseAPIClient]): + """Base API test case to be inherited by all other API test cases.""" + + client: AnyBaseAPIClient + client_class: t.Type[AnyBaseAPIClient] + + +class APITestCase( + BaseAPITestCase[APIClient[RequestUser]], + t.Generic[RequestUser], +): """Base API test case to be inherited by all other API test cases.""" - client: APIClient[RequestUser] - client_class: t.Type[APIClient[RequestUser]] = APIClient + client_class = APIClient @classmethod def get_request_user_class(cls) -> t.Type[RequestUser]: diff --git a/codeforlife/tests/api_client.py b/codeforlife/tests/api_client.py index 31aa4c4e..b2113504 100644 --- a/codeforlife/tests/api_client.py +++ b/codeforlife/tests/api_client.py @@ -13,25 +13,42 @@ from rest_framework.test import APIClient as _APIClient from ..types import DataDict, JsonDict -from .api_request_factory import APIRequestFactory +from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory # pylint: disable=duplicate-code if t.TYPE_CHECKING: from ..user.models import TypedUser, User - from .api import APITestCase + from .api import APITestCase, BaseAPITestCase RequestUser = t.TypeVar("RequestUser", bound=User) LoginUser = t.TypeVar("LoginUser", bound=User) + AnyBaseAPITestCase = t.TypeVar("AnyBaseAPITestCase", bound=BaseAPITestCase) else: RequestUser = t.TypeVar("RequestUser") LoginUser = t.TypeVar("LoginUser") + AnyBaseAPITestCase = t.TypeVar("AnyBaseAPITestCase") + +AnyBaseAPIRequestFactory = t.TypeVar( + "AnyBaseAPIRequestFactory", bound=BaseAPIRequestFactory +) # pylint: enable=duplicate-code -class APIClient(_APIClient, t.Generic[RequestUser]): +class BaseAPIClient( + _APIClient, + t.Generic[AnyBaseAPITestCase, AnyBaseAPIRequestFactory], +): """Base API client to be inherited by all other API clients.""" - _test_case: "APITestCase[RequestUser]" + _test_case: AnyBaseAPITestCase + + request_factory: AnyBaseAPIRequestFactory + request_factory_class: t.Type[AnyBaseAPIRequestFactory] + + def _initialize_request_factory( + self, enforce_csrf_checks: bool, **defaults + ): + return self.request_factory_class(enforce_csrf_checks, **defaults) def __init__( self, @@ -45,24 +62,10 @@ def __init__( **defaults, ) - self.request_factory = APIRequestFactory( - self.get_request_user_class(), - enforce_csrf_checks, - **defaults, + self.request_factory = self._initialize_request_factory( + enforce_csrf_checks, **defaults ) - @classmethod - def get_request_user_class(cls) -> t.Type[RequestUser]: - """Get the request's user class. - - Returns: - The request's user class. - """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] - @staticmethod def status_code_is_ok(status_code: int): """Check if the status code is greater than or equal to 200 and less @@ -110,6 +113,227 @@ def _make_assertions(): self._assert_response(response, _make_assertions) + # -------------------------------------------------------------------------- + # Request Helpers + # -------------------------------------------------------------------------- + + StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]] + + # pylint: disable=too-many-arguments,redefined-builtin + + def generic( + self, + method, + path, + data="", + content_type="application/json", + secure=False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + response = t.cast( + Response, + super().generic( + method, + path, + data, + content_type, + secure, + **extra, + ), + ) + + # Use a custom kwarg to handle the common case of checking the + # response's status code. + if status_code_assertion is None: + status_code_assertion = self.status_code_is_ok + elif isinstance(status_code_assertion, int): + expected_status_code = status_code_assertion + status_code_assertion = ( + # pylint: disable-next=unnecessary-lambda-assignment + lambda status_code: status_code + == expected_status_code + ) + + # pylint: disable-next=no-member + status_code = response.status_code + assert status_code_assertion( + status_code + ), f"Unexpected status code: {status_code}." + ( + "\nValidation errors: " + + json.dumps( + # pylint: disable-next=no-member + response.json(), # type: ignore[attr-defined] + indent=2, + default=str, + ) + if status_code == status.HTTP_400_BAD_REQUEST + else "" + ) + + return response + + def get( # type: ignore[override] + self, + path: str, + data: t.Any = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + return super().get( + path=path, + data=data, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + def post( # type: ignore[override] + self, + path: str, + data: t.Any = None, + format: t.Optional[str] = None, + content_type: t.Optional[str] = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + if format is None and content_type is None: + format = "json" + + return super().post( + path=path, + data=data, + format=format, + content_type=content_type, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + def put( # type: ignore[override] + self, + path: str, + data: t.Any = None, + format: t.Optional[str] = None, + content_type: t.Optional[str] = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + if format is None and content_type is None: + format = "json" + + return super().put( + path=path, + data=data, + format=format, + content_type=content_type, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + def patch( # type: ignore[override] + self, + path: str, + data: t.Any = None, + format: t.Optional[str] = None, + content_type: t.Optional[str] = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + if format is None and content_type is None: + format = "json" + + return super().patch( + path=path, + data=data, + format=format, + content_type=content_type, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + def delete( # type: ignore[override] + self, + path: str, + data: t.Any = None, + format: t.Optional[str] = None, + content_type: t.Optional[str] = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + if format is None and content_type is None: + format = "json" + + return super().delete( + path=path, + data=data, + format=format, + content_type=content_type, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + def options( # type: ignore[override] + self, + path: str, + data: t.Any = None, + format: t.Optional[str] = None, + content_type: t.Optional[str] = None, + follow: bool = False, + status_code_assertion: StatusCodeAssertion = None, + **extra, + ): + if format is None and content_type is None: + format = "json" + + return super().options( + path=path, + data=data, + format=format, + content_type=content_type, + follow=follow, + status_code_assertion=status_code_assertion, + **extra, + ) + + # pylint: enable=too-many-arguments,redefined-builtin + + +class APIClient( + BaseAPIClient["APITestCase[RequestUser]", APIRequestFactory[RequestUser]], + t.Generic[RequestUser], +): + """Base API client to be inherited by all other API clients.""" + + request_factory_class = APIRequestFactory + + def _initialize_request_factory(self, enforce_csrf_checks, **defaults): + return self.request_factory_class( + self.get_request_user_class(), + enforce_csrf_checks, + **defaults, + ) + + @classmethod + def get_request_user_class(cls) -> t.Type[RequestUser]: + """Get the request's user class. + + Returns: + The request's user class. + """ + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 0 + ] + # -------------------------------------------------------------------------- # Login Helpers # -------------------------------------------------------------------------- @@ -323,196 +547,3 @@ def login_as(self, user: "TypedUser", password: str = "password"): auth_user = self.login_indy(user.email, password) assert user == auth_user - - # -------------------------------------------------------------------------- - # Request Helpers - # -------------------------------------------------------------------------- - - StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]] - - # pylint: disable=too-many-arguments,redefined-builtin - - def generic( - self, - method, - path, - data="", - content_type="application/json", - secure=False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - response = t.cast( - Response, - super().generic( - method, - path, - data, - content_type, - secure, - **extra, - ), - ) - - # Use a custom kwarg to handle the common case of checking the - # response's status code. - if status_code_assertion is None: - status_code_assertion = self.status_code_is_ok - elif isinstance(status_code_assertion, int): - expected_status_code = status_code_assertion - status_code_assertion = ( - # pylint: disable-next=unnecessary-lambda-assignment - lambda status_code: status_code - == expected_status_code - ) - - # pylint: disable-next=no-member - status_code = response.status_code - assert status_code_assertion( - status_code - ), f"Unexpected status code: {status_code}." + ( - "\nValidation errors: " - + json.dumps( - # pylint: disable-next=no-member - response.json(), # type: ignore[attr-defined] - indent=2, - default=str, - ) - if status_code == status.HTTP_400_BAD_REQUEST - else "" - ) - - return response - - def get( # type: ignore[override] - self, - path: str, - data: t.Any = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - return super().get( - path=path, - data=data, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - def post( # type: ignore[override] - self, - path: str, - data: t.Any = None, - format: t.Optional[str] = None, - content_type: t.Optional[str] = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - if format is None and content_type is None: - format = "json" - - return super().post( - path=path, - data=data, - format=format, - content_type=content_type, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - def put( # type: ignore[override] - self, - path: str, - data: t.Any = None, - format: t.Optional[str] = None, - content_type: t.Optional[str] = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - if format is None and content_type is None: - format = "json" - - return super().put( - path=path, - data=data, - format=format, - content_type=content_type, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - def patch( # type: ignore[override] - self, - path: str, - data: t.Any = None, - format: t.Optional[str] = None, - content_type: t.Optional[str] = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - if format is None and content_type is None: - format = "json" - - return super().patch( - path=path, - data=data, - format=format, - content_type=content_type, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - def delete( # type: ignore[override] - self, - path: str, - data: t.Any = None, - format: t.Optional[str] = None, - content_type: t.Optional[str] = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - if format is None and content_type is None: - format = "json" - - return super().delete( - path=path, - data=data, - format=format, - content_type=content_type, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - def options( # type: ignore[override] - self, - path: str, - data: t.Any = None, - format: t.Optional[str] = None, - content_type: t.Optional[str] = None, - follow: bool = False, - status_code_assertion: StatusCodeAssertion = None, - **extra, - ): - if format is None and content_type is None: - format = "json" - - return super().options( - path=path, - data=data, - format=format, - content_type=content_type, - follow=follow, - status_code_assertion=status_code_assertion, - **extra, - ) - - # pylint: enable=too-many-arguments,redefined-builtin From 0f503e9437ef6bf98c82993e382215628af61f69 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 6 Nov 2024 14:14:38 +0000 Subject: [PATCH 29/56] # pylint: disable-next=too-many-ancestors --- codeforlife/tests/model_view_set.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index 3a373021..97e1cb0b 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -35,6 +35,7 @@ # pylint: disable=no-member,too-many-arguments +# pylint: disable-next=too-many-ancestors class ModelViewSetClient( APIClient[RequestUser], t.Generic[RequestUser, AnyModel] ): @@ -613,6 +614,7 @@ def cron_job(self, action: str): # pylint: enable=no-member +# pylint: disable-next=too-many-ancestors class ModelViewSetTestCase( APITestCase[RequestUser], t.Generic[RequestUser, AnyModel] ): From 33d68a91816410af027df3bf7caf81cb2ab9c76c Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 6 Nov 2024 14:18:30 +0000 Subject: [PATCH 30/56] split code --- codeforlife/tests/__init__.py | 3 +- codeforlife/tests/model_view_set.py | 587 +------------------- codeforlife/tests/model_view_set_client.py | 608 +++++++++++++++++++++ 3 files changed, 612 insertions(+), 586 deletions(-) create mode 100644 codeforlife/tests/model_view_set_client.py diff --git a/codeforlife/tests/__init__.py b/codeforlife/tests/__init__.py index 58dbce7c..a6d0ef65 100644 --- a/codeforlife/tests/__init__.py +++ b/codeforlife/tests/__init__.py @@ -18,5 +18,6 @@ BaseModelSerializerTestCase, ModelSerializerTestCase, ) -from .model_view_set import ModelViewSetClient, ModelViewSetTestCase +from .model_view_set import ModelViewSetTestCase +from .model_view_set_client import ModelViewSetClient from .test import Client, TestCase diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index 97e1cb0b..58b0e668 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -10,17 +10,14 @@ from django.core.exceptions import ObjectDoesNotExist from django.db.models import Model -from django.db.models.query import QuerySet from django.urls import reverse -from django.utils.http import urlencode -from rest_framework import status -from rest_framework.response import Response from ..permissions import Permission from ..serializers import BaseSerializer from ..types import DataDict, JsonDict, KwArgs from ..views import ModelViewSet -from .api import APIClient, APITestCase +from .api import APITestCase +from .model_view_set_client import ModelViewSetClient # pylint: disable-next=duplicate-code if t.TYPE_CHECKING: @@ -31,589 +28,9 @@ RequestUser = t.TypeVar("RequestUser") AnyModel = t.TypeVar("AnyModel", bound=Model) - # pylint: disable=no-member,too-many-arguments -# pylint: disable-next=too-many-ancestors -class ModelViewSetClient( - APIClient[RequestUser], t.Generic[RequestUser, AnyModel] -): - """ - An API client that helps make requests to a model view set and assert their - responses. - """ - - _test_case: "ModelViewSetTestCase[RequestUser, AnyModel]" - - @property - def _model_class(self): - """Shortcut to get model class.""" - return self._test_case.get_model_class() - - @property - def _model_view_set_class(self): - """Shortcut to get model view set class.""" - return self._test_case.model_view_set_class - - # -------------------------------------------------------------------------- - # Create (HTTP POST) - # -------------------------------------------------------------------------- - - def _assert_create(self, json_model: JsonDict, action: str): - model = self._model_class.objects.get( - **{self._model_view_set_class.lookup_field: json_model["id"]} - ) - self._test_case.assert_serialized_model_equals_json_model( - model, json_model, action, request_method="post" - ) - - def create( - self, - data: DataDict, - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_201_CREATED - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Create a model. - - Args: - data: The values for each field. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - - response: Response = self.post( - self._test_case.reverse_action("list", kwargs=reverse_kwargs), - data=data, - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - self._assert_response_json( - response, - lambda json_model: self._assert_create( - json_model, action="create" - ), - ) - - return response - - def bulk_create( - self, - data: t.List[DataDict], - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_201_CREATED - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Bulk create many instances of a model. - - Args: - data: The values for each field, for each model. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - - response: Response = self.post( - self._test_case.reverse_action("bulk", kwargs=reverse_kwargs), - data=data, - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - - def _make_assertions(json_models: t.List[JsonDict]): - for json_model in json_models: - self._assert_create(json_model, action="bulk") - - self._assert_response_json_bulk(response, _make_assertions, data) - - return response - - # -------------------------------------------------------------------------- - # Retrieve (HTTP GET) - # -------------------------------------------------------------------------- - - def retrieve( - self, - model: AnyModel, - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_200_OK - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Retrieve a model. - - Args: - model: The model to retrieve. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - - response: Response = self.get( - self._test_case.reverse_action( - "detail", - model, - kwargs=reverse_kwargs, - ), - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - self._assert_response_json( - response, - make_assertions=lambda json_model: ( - self._test_case.assert_serialized_model_equals_json_model( - model, - json_model, - action="retrieve", - request_method="get", - ) - ), - ) - - return response - - def list( - self, - models: t.Collection[AnyModel], - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_200_OK - ), - make_assertions: bool = True, - filters: t.Optional[t.Dict[str, t.Union[str, t.Iterable[str]]]] = None, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Retrieve a list of models. - - Args: - models: The model list to retrieve. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - filters: The filters to apply to the list. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - - query: t.List[t.Tuple[str, str]] = [] - for key, values in (filters or {}).items(): - if isinstance(values, str): - query.append((key, values)) - else: - for value in values: - query.append((key, value)) - - response: Response = self.get( - ( - self._test_case.reverse_action("list", kwargs=reverse_kwargs) - + f"?{urlencode(query)}" - ), - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - - def _make_assertions(response_json: JsonDict): - json_models = t.cast(t.List[JsonDict], response_json["data"]) - assert len(models) == len(json_models) - for model, json_model in zip(models, json_models): - self._test_case.assert_serialized_model_equals_json_model( - model, json_model, action="list", request_method="get" - ) - - self._assert_response_json(response, _make_assertions) - - return response - - # -------------------------------------------------------------------------- - # Partial Update (HTTP PATCH) - # -------------------------------------------------------------------------- - - def _assert_update( - self, - model: AnyModel, - json_model: JsonDict, - action: str, - request_method: str, - partial: bool, - ): - model.refresh_from_db() - self._test_case.assert_serialized_model_equals_json_model( - model, json_model, action, request_method, contains_subset=partial - ) - - def partial_update( - self, - model: AnyModel, - data: DataDict, - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_200_OK - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Partially update a model. - - Args: - model: The model to partially update. - data: The values for each field. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - response: Response = self.patch( - self._test_case.reverse_action( - "detail", - model, - kwargs=reverse_kwargs, - ), - data=data, - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - self._assert_response_json( - response, - make_assertions=lambda json_model: self._assert_update( - model, - json_model, - action="partial_update", - request_method="patch", - partial=True, - ), - ) - - return response - - def bulk_partial_update( - self, - models: t.Union[t.List[AnyModel], QuerySet[AnyModel]], - data: t.List[DataDict], - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_200_OK - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Bulk partially update many instances of a model. - - Args: - models: The models to partially update. - data: The values for each field, for each model. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - if not isinstance(models, list): - models = list(models) - - response: Response = self.patch( - self._test_case.reverse_action("bulk", kwargs=reverse_kwargs), - data=data, - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - - def _make_assertions(json_models: t.List[JsonDict]): - models.sort( - key=lambda model: getattr( - model, self._model_view_set_class.lookup_field - ) - ) - for model, json_model in zip(models, json_models): - self._assert_update( - model, - json_model, - action="bulk", - request_method="patch", - partial=True, - ) - - self._assert_response_json_bulk(response, _make_assertions, data) - - return response - - # -------------------------------------------------------------------------- - # Update (HTTP PUT) - # -------------------------------------------------------------------------- - - def update( - self, - model: AnyModel, - action: str, - data: t.Optional[DataDict] = None, - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_200_OK - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Update a model. - - Args: - model: The model to update. - action: The name of the action. - data: The values for each field. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - response = self.put( - path=self._test_case.reverse_action( - action, model, kwargs=reverse_kwargs - ), - data=data, - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - self._assert_response_json( - response, - make_assertions=lambda json_model: self._assert_update( - model, - json_model, - action, - request_method="put", - partial=False, - ), - ) - - return response - - def bulk_update( - self, - models: t.Union[t.List[AnyModel], QuerySet[AnyModel]], - data: t.List[DataDict], - action: str, - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_200_OK - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Bulk update many instances of a model. - - Args: - models: The models to update. - data: The values for each field, for each model. - action: The name of the action. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - if not isinstance(models, list): - models = list(models) - - assert models - assert len(models) == len(data) - - response = self.put( - self._test_case.reverse_action(action, kwargs=reverse_kwargs), - data={ - getattr(model, self._model_view_set_class.lookup_field): _data - for model, _data in zip(models, data) - }, - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - - def _make_assertions(json_models: t.List[JsonDict]): - models.sort( - key=lambda model: getattr( - model, self._model_view_set_class.lookup_field - ) - ) - for model, json_model in zip(models, json_models): - self._assert_update( - model, - json_model, - action, - request_method="put", - partial=False, - ) - - self._assert_response_json_bulk(response, _make_assertions, data) - - return response - - # -------------------------------------------------------------------------- - # Destroy (HTTP DELETE) - # -------------------------------------------------------------------------- - - def _assert_destroy(self, lookup_values: t.List): - assert not self._model_class.objects.filter( - **{f"{self._model_view_set_class.lookup_field}__in": lookup_values} - ).exists() - - def destroy( - self, - model: AnyModel, - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_204_NO_CONTENT - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Destroy a model. - - Args: - model: The model to destroy. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - - response: Response = self.delete( - self._test_case.reverse_action( - "detail", - model, - kwargs=reverse_kwargs, - ), - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - self._assert_response( - response, - make_assertions=lambda: self._assert_destroy([model.pk]), - ) - - return response - - def bulk_destroy( - self, - data: t.List, - status_code_assertion: APIClient.StatusCodeAssertion = ( - status.HTTP_204_NO_CONTENT - ), - make_assertions: bool = True, - reverse_kwargs: t.Optional[KwArgs] = None, - **kwargs, - ): - # pylint: disable=line-too-long - """Bulk destroy many instances of a model. - - Args: - data: The primary keys of the models to lookup and destroy. - status_code_assertion: The expected status code. - make_assertions: A flag designating whether to make the default assertions. - reverse_kwargs: The kwargs for the reverse URL. - - Returns: - The HTTP response. - """ - # pylint: enable=line-too-long - - response: Response = self.delete( - self._test_case.reverse_action("bulk", kwargs=reverse_kwargs), - data=data, - status_code_assertion=status_code_assertion, - **kwargs, - ) - - if make_assertions: - self._assert_response( - response, make_assertions=lambda: self._assert_destroy(data) - ) - - return response - - # -------------------------------------------------------------------------- - # OTHER - # -------------------------------------------------------------------------- - - def cron_job(self, action: str): - """Call a CRON job. - - Args: - action: The name of the action. - - Returns: - The HTTP response. - """ - response: Response = self.get( - self._test_case.reverse_action(action), - HTTP_X_APPENGINE_CRON="true", - ) - - return response - - -# pylint: enable=no-member - - # pylint: disable-next=too-many-ancestors class ModelViewSetTestCase( APITestCase[RequestUser], t.Generic[RequestUser, AnyModel] diff --git a/codeforlife/tests/model_view_set_client.py b/codeforlife/tests/model_view_set_client.py new file mode 100644 index 00000000..a774bb0d --- /dev/null +++ b/codeforlife/tests/model_view_set_client.py @@ -0,0 +1,608 @@ +""" +© Ocado Group +Created on 06/11/2024 at 14:13:31(+00:00). + +Base test case for all model view clients. +""" + +import typing as t + +from django.db.models import Model +from django.db.models.query import QuerySet +from django.utils.http import urlencode +from rest_framework import status +from rest_framework.response import Response + +from ..types import DataDict, JsonDict, KwArgs +from .api import APIClient + +# pylint: disable-next=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User + from .model_view_set import ModelViewSetTestCase + + RequestUser = t.TypeVar("RequestUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") + +AnyModel = t.TypeVar("AnyModel", bound=Model) +# pylint: disable=no-member,too-many-arguments + + +# pylint: disable-next=too-many-ancestors +class ModelViewSetClient( + APIClient[RequestUser], t.Generic[RequestUser, AnyModel] +): + """ + An API client that helps make requests to a model view set and assert their + responses. + """ + + _test_case: "ModelViewSetTestCase[RequestUser, AnyModel]" + + @property + def _model_class(self): + """Shortcut to get model class.""" + return self._test_case.get_model_class() + + @property + def _model_view_set_class(self): + """Shortcut to get model view set class.""" + return self._test_case.model_view_set_class + + # -------------------------------------------------------------------------- + # Create (HTTP POST) + # -------------------------------------------------------------------------- + + def _assert_create(self, json_model: JsonDict, action: str): + model = self._model_class.objects.get( + **{self._model_view_set_class.lookup_field: json_model["id"]} + ) + self._test_case.assert_serialized_model_equals_json_model( + model, json_model, action, request_method="post" + ) + + def create( + self, + data: DataDict, + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_201_CREATED + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Create a model. + + Args: + data: The values for each field. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + + response: Response = self.post( + self._test_case.reverse_action("list", kwargs=reverse_kwargs), + data=data, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + self._assert_response_json( + response, + lambda json_model: self._assert_create( + json_model, action="create" + ), + ) + + return response + + def bulk_create( + self, + data: t.List[DataDict], + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_201_CREATED + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Bulk create many instances of a model. + + Args: + data: The values for each field, for each model. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + + response: Response = self.post( + self._test_case.reverse_action("bulk", kwargs=reverse_kwargs), + data=data, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + + def _make_assertions(json_models: t.List[JsonDict]): + for json_model in json_models: + self._assert_create(json_model, action="bulk") + + self._assert_response_json_bulk(response, _make_assertions, data) + + return response + + # -------------------------------------------------------------------------- + # Retrieve (HTTP GET) + # -------------------------------------------------------------------------- + + def retrieve( + self, + model: AnyModel, + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_200_OK + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Retrieve a model. + + Args: + model: The model to retrieve. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + + response: Response = self.get( + self._test_case.reverse_action( + "detail", + model, + kwargs=reverse_kwargs, + ), + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + self._assert_response_json( + response, + make_assertions=lambda json_model: ( + self._test_case.assert_serialized_model_equals_json_model( + model, + json_model, + action="retrieve", + request_method="get", + ) + ), + ) + + return response + + def list( + self, + models: t.Collection[AnyModel], + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_200_OK + ), + make_assertions: bool = True, + filters: t.Optional[t.Dict[str, t.Union[str, t.Iterable[str]]]] = None, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Retrieve a list of models. + + Args: + models: The model list to retrieve. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + filters: The filters to apply to the list. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + + query: t.List[t.Tuple[str, str]] = [] + for key, values in (filters or {}).items(): + if isinstance(values, str): + query.append((key, values)) + else: + for value in values: + query.append((key, value)) + + response: Response = self.get( + ( + self._test_case.reverse_action("list", kwargs=reverse_kwargs) + + f"?{urlencode(query)}" + ), + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + + def _make_assertions(response_json: JsonDict): + json_models = t.cast(t.List[JsonDict], response_json["data"]) + assert len(models) == len(json_models) + for model, json_model in zip(models, json_models): + self._test_case.assert_serialized_model_equals_json_model( + model, json_model, action="list", request_method="get" + ) + + self._assert_response_json(response, _make_assertions) + + return response + + # -------------------------------------------------------------------------- + # Partial Update (HTTP PATCH) + # -------------------------------------------------------------------------- + + def _assert_update( + self, + model: AnyModel, + json_model: JsonDict, + action: str, + request_method: str, + partial: bool, + ): + model.refresh_from_db() + self._test_case.assert_serialized_model_equals_json_model( + model, json_model, action, request_method, contains_subset=partial + ) + + def partial_update( + self, + model: AnyModel, + data: DataDict, + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_200_OK + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Partially update a model. + + Args: + model: The model to partially update. + data: The values for each field. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + response: Response = self.patch( + self._test_case.reverse_action( + "detail", + model, + kwargs=reverse_kwargs, + ), + data=data, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + self._assert_response_json( + response, + make_assertions=lambda json_model: self._assert_update( + model, + json_model, + action="partial_update", + request_method="patch", + partial=True, + ), + ) + + return response + + def bulk_partial_update( + self, + models: t.Union[t.List[AnyModel], QuerySet[AnyModel]], + data: t.List[DataDict], + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_200_OK + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Bulk partially update many instances of a model. + + Args: + models: The models to partially update. + data: The values for each field, for each model. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + if not isinstance(models, list): + models = list(models) + + response: Response = self.patch( + self._test_case.reverse_action("bulk", kwargs=reverse_kwargs), + data=data, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + + def _make_assertions(json_models: t.List[JsonDict]): + models.sort( + key=lambda model: getattr( + model, self._model_view_set_class.lookup_field + ) + ) + for model, json_model in zip(models, json_models): + self._assert_update( + model, + json_model, + action="bulk", + request_method="patch", + partial=True, + ) + + self._assert_response_json_bulk(response, _make_assertions, data) + + return response + + # -------------------------------------------------------------------------- + # Update (HTTP PUT) + # -------------------------------------------------------------------------- + + def update( + self, + model: AnyModel, + action: str, + data: t.Optional[DataDict] = None, + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_200_OK + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Update a model. + + Args: + model: The model to update. + action: The name of the action. + data: The values for each field. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + response = self.put( + path=self._test_case.reverse_action( + action, model, kwargs=reverse_kwargs + ), + data=data, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + self._assert_response_json( + response, + make_assertions=lambda json_model: self._assert_update( + model, + json_model, + action, + request_method="put", + partial=False, + ), + ) + + return response + + def bulk_update( + self, + models: t.Union[t.List[AnyModel], QuerySet[AnyModel]], + data: t.List[DataDict], + action: str, + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_200_OK + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Bulk update many instances of a model. + + Args: + models: The models to update. + data: The values for each field, for each model. + action: The name of the action. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + if not isinstance(models, list): + models = list(models) + + assert models + assert len(models) == len(data) + + response = self.put( + self._test_case.reverse_action(action, kwargs=reverse_kwargs), + data={ + getattr(model, self._model_view_set_class.lookup_field): _data + for model, _data in zip(models, data) + }, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + + def _make_assertions(json_models: t.List[JsonDict]): + models.sort( + key=lambda model: getattr( + model, self._model_view_set_class.lookup_field + ) + ) + for model, json_model in zip(models, json_models): + self._assert_update( + model, + json_model, + action, + request_method="put", + partial=False, + ) + + self._assert_response_json_bulk(response, _make_assertions, data) + + return response + + # -------------------------------------------------------------------------- + # Destroy (HTTP DELETE) + # -------------------------------------------------------------------------- + + def _assert_destroy(self, lookup_values: t.List): + assert not self._model_class.objects.filter( + **{f"{self._model_view_set_class.lookup_field}__in": lookup_values} + ).exists() + + def destroy( + self, + model: AnyModel, + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_204_NO_CONTENT + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Destroy a model. + + Args: + model: The model to destroy. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + + response: Response = self.delete( + self._test_case.reverse_action( + "detail", + model, + kwargs=reverse_kwargs, + ), + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + self._assert_response( + response, + make_assertions=lambda: self._assert_destroy([model.pk]), + ) + + return response + + def bulk_destroy( + self, + data: t.List, + status_code_assertion: APIClient.StatusCodeAssertion = ( + status.HTTP_204_NO_CONTENT + ), + make_assertions: bool = True, + reverse_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Bulk destroy many instances of a model. + + Args: + data: The primary keys of the models to lookup and destroy. + status_code_assertion: The expected status code. + make_assertions: A flag designating whether to make the default assertions. + reverse_kwargs: The kwargs for the reverse URL. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + + response: Response = self.delete( + self._test_case.reverse_action("bulk", kwargs=reverse_kwargs), + data=data, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if make_assertions: + self._assert_response( + response, make_assertions=lambda: self._assert_destroy(data) + ) + + return response + + # -------------------------------------------------------------------------- + # OTHER + # -------------------------------------------------------------------------- + + def cron_job(self, action: str): + """Call a CRON job. + + Args: + action: The name of the action. + + Returns: + The HTTP response. + """ + response: Response = self.get( + self._test_case.reverse_action(action), + HTTP_X_APPENGINE_CRON="true", + ) + + return response + + +# pylint: enable=no-member From b05850582bd62acf088875258c77509e490c31f8 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 6 Nov 2024 19:04:24 +0000 Subject: [PATCH 31/56] abstract user and session --- codeforlife/models/__init__.py | 3 + codeforlife/models/abstract_base_session.py | 60 ++++++++++++++ codeforlife/models/abstract_base_user.py | 33 ++++++++ codeforlife/models/base_session_store.py | 89 +++++++++++++++++++++ codeforlife/user/models/session.py | 68 +++------------- codeforlife/user/models/user.py | 13 ++- 6 files changed, 210 insertions(+), 56 deletions(-) create mode 100644 codeforlife/models/abstract_base_session.py create mode 100644 codeforlife/models/abstract_base_user.py create mode 100644 codeforlife/models/base_session_store.py diff --git a/codeforlife/models/__init__.py b/codeforlife/models/__init__.py index 4335ad17..a8faab8f 100644 --- a/codeforlife/models/__init__.py +++ b/codeforlife/models/__init__.py @@ -3,4 +3,7 @@ Created on 19/01/2024 at 15:20:45(+00:00). """ +from .abstract_base_session import AbstractBaseSession +from .abstract_base_user import AbstractBaseUser from .base import * +from .base_session_store import BaseSessionStore diff --git a/codeforlife/models/abstract_base_session.py b/codeforlife/models/abstract_base_session.py new file mode 100644 index 00000000..a7522753 --- /dev/null +++ b/codeforlife/models/abstract_base_session.py @@ -0,0 +1,60 @@ +""" +© Ocado Group +Created on 06/11/2024 at 16:44:56(+00:00). +""" + +import typing as t + +from django.contrib.auth import get_user_model +from django.contrib.sessions.base_session import ( + AbstractBaseSession as _AbstractBaseSession, +) +from django.db import models +from django.utils import timezone +from django.utils.translation import gettext_lazy as _ + +from .abstract_base_user import AbstractBaseUser + +if t.TYPE_CHECKING: + from django_stubs_ext.db.models import TypedModelMeta + + from .base_session_store import BaseSessionStore +else: + TypedModelMeta = object + + +class AbstractBaseSession(_AbstractBaseSession): + """ + Base session class to be inherited by all session classes. + https://docs.djangoproject.com/en/3.2/topics/http/sessions/#example + """ + + pk: str # type: ignore[assignment] + + user_id: int + user = models.OneToOneField( + t.cast(t.Type[AbstractBaseUser], get_user_model()), + null=True, + blank=True, + on_delete=models.CASCADE, + ) + + # pylint: disable-next=missing-class-docstring,too-few-public-methods + class Meta(TypedModelMeta): + abstract = True + verbose_name = _("session") + verbose_name_plural = _("sessions") + + @property + def is_expired(self): + """Whether or not this session has expired.""" + return self.expire_date < timezone.now() + + @property + def store(self): + """A store instance for this session.""" + return self.get_session_store_class()(self.session_key) + + @classmethod + def get_session_store_class(cls) -> t.Type["BaseSessionStore"]: + raise NotImplementedError diff --git a/codeforlife/models/abstract_base_user.py b/codeforlife/models/abstract_base_user.py new file mode 100644 index 00000000..3323fe82 --- /dev/null +++ b/codeforlife/models/abstract_base_user.py @@ -0,0 +1,33 @@ +""" +© Ocado Group +Created on 06/11/2024 at 16:38:15(+00:00). +""" + +import typing as t + +from django.contrib.auth.models import AbstractBaseUser as _AbstractBaseUser +from django.utils.translation import gettext_lazy as _ + +if t.TYPE_CHECKING: + from django_stubs_ext.db.models import TypedModelMeta + + from .abstract_base_session import AbstractBaseSession +else: + TypedModelMeta = object + + +class AbstractBaseUser(_AbstractBaseUser): + """ + Base user class to be inherited by all user classes. + https://docs.djangoproject.com/en/3.2/topics/auth/customizing/#using-a-custom-user-model-when-starting-a-project + """ + + id: int + pk: int + session: "AbstractBaseSession" + + # pylint: disable-next=missing-class-docstring,too-few-public-methods + class Meta(TypedModelMeta): + abstract = True + verbose_name = _("user") + verbose_name_plural = _("users") diff --git a/codeforlife/models/base_session_store.py b/codeforlife/models/base_session_store.py new file mode 100644 index 00000000..facaee5b --- /dev/null +++ b/codeforlife/models/base_session_store.py @@ -0,0 +1,89 @@ +""" +© Ocado Group +Created on 06/11/2024 at 17:31:32(+00:00). +""" + +import typing as t + +from django.contrib.auth import SESSION_KEY +from django.contrib.sessions.backends.db import SessionStore +from django.utils import timezone + +if t.TYPE_CHECKING: + from .abstract_base_session import AbstractBaseSession + from .abstract_base_user import AbstractBaseUser + + AnyAbstractBaseSession = t.TypeVar( + "AnyAbstractBaseSession", bound=AbstractBaseSession + ) + AnyAbstractBaseUser = t.TypeVar( + "AnyAbstractBaseUser", bound=AbstractBaseUser + ) +else: + AnyAbstractBaseSession = t.TypeVar("AnyAbstractBaseSession") + AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser") + + +class BaseSessionStore( + SessionStore, + t.Generic[AnyAbstractBaseSession, AnyAbstractBaseUser], +): + """ + Base session store class to be inherited by all session store classes. + https://docs.djangoproject.com/en/3.2/topics/http/sessions/#example + """ + + @classmethod + def get_model_class(cls) -> t.Type[AnyAbstractBaseSession]: + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 0 + ] + + @classmethod + def get_user_class(cls) -> t.Type[AnyAbstractBaseUser]: + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 1 + ] + + def associate_session_to_user( + self, session: AnyAbstractBaseSession, user_id: int + ): + """Associate an anon session to a user. + + Args: + session: The anon session. + user_id: The user to associate. + """ + session.user = self.get_user_class().objects.get(id=user_id) + + def create_model_instance(self, data): + try: + user_id = int(data.get(SESSION_KEY)) + except (ValueError, TypeError): + # Create an anon session. + return super().create_model_instance(data) + + model_class = self.get_model_class() + + try: + session = model_class.objects.get(user_id=user_id) + except model_class.DoesNotExist: + session = model_class.objects.get(session_key=self.session_key) + self.associate_session_to_user(session, user_id) + + session.session_data = self.encode(data) + + return session + + @classmethod + def clear_expired(cls, user_id=None): + session_query = cls.get_model_class().objects.filter( + expire_date__lt=timezone.now() + ) + + if user_id is not None: + session_query = session_query.filter(user_id=user_id) + + session_query.delete() diff --git a/codeforlife/user/models/session.py b/codeforlife/user/models/session.py index 25ca724f..a6df0f95 100644 --- a/codeforlife/user/models/session.py +++ b/codeforlife/user/models/session.py @@ -5,13 +5,10 @@ import typing as t -from django.contrib.auth import SESSION_KEY -from django.contrib.sessions.backends.db import SessionStore as DBStore -from django.contrib.sessions.base_session import AbstractBaseSession from django.db import models from django.db.models.query import QuerySet -from django.utils import timezone +from ...models import AbstractBaseSession, BaseSessionStore from .user import User if t.TYPE_CHECKING: # pragma: no cover @@ -26,29 +23,20 @@ class Session(AbstractBaseSession): auth_factors: QuerySet["SessionAuthFactor"] - user = models.OneToOneField( + # TODO: remove in new schema + user = models.OneToOneField( # type: ignore[assignment] User, null=True, blank=True, on_delete=models.CASCADE, ) - @property - def is_expired(self): - """Whether or not this session has expired.""" - return self.expire_date < timezone.now() - - @property - def store(self): - """A store instance for this session.""" - return self.get_session_store_class()(self.session_key) - @classmethod def get_session_store_class(cls): return SessionStore -class SessionStore(DBStore): +class SessionStore(BaseSessionStore[Session, User]): """ A custom session store interface to support: 1. creating only one session per user; @@ -57,44 +45,14 @@ class SessionStore(DBStore): https://docs.djangoproject.com/en/3.2/topics/http/sessions/#example """ - @classmethod - def get_model_class(cls): - return Session - - def create_model_instance(self, data): - try: - user_id = int(data.get(SESSION_KEY)) - except (ValueError, TypeError): - # Create an anon session. - return super().create_model_instance(data) - - model_class = self.get_model_class() + def associate_session_to_user(self, session, user_id): + # pylint: disable-next=import-outside-toplevel + from .session_auth_factor import SessionAuthFactor - try: - session = model_class.objects.get(user_id=user_id) - except model_class.DoesNotExist: - # pylint: disable-next=import-outside-toplevel - from .session_auth_factor import SessionAuthFactor - - # Associate session to user. - session = model_class.objects.get(session_key=self.session_key) - session.user = User.objects.get(id=user_id) - SessionAuthFactor.objects.bulk_create( - [ - SessionAuthFactor(session=session, auth_factor=auth_factor) - for auth_factor in session.user.auth_factors.all() - ] - ) - - session.session_data = self.encode(data) - - return session - - @classmethod - def clear_expired(cls, user_id: t.Optional[int] = None): - session_query = cls.get_model_class().objects.filter( - expire_date__lt=timezone.now() + super().associate_session_to_user(session, user_id) + SessionAuthFactor.objects.bulk_create( + [ + SessionAuthFactor(session=session, auth_factor=auth_factor) + for auth_factor in session.user.auth_factors.all() + ] ) - if user_id: - session_query = session_query.filter(user_id=user_id) - session_query.delete() diff --git a/codeforlife/user/models/user.py b/codeforlife/user/models/user.py index b579cdb5..e36dc08d 100644 --- a/codeforlife/user/models/user.py +++ b/codeforlife/user/models/user.py @@ -6,6 +6,7 @@ """ import string import typing as t +from datetime import datetime from common.models import TotalActivity, UserProfile @@ -18,6 +19,7 @@ from pyotp import TOTP from ... import mail +from ...models import AbstractBaseUser from .klass import Class from .school import School @@ -33,7 +35,16 @@ TypedModelMeta = object -class User(_User): +# TODO: remove in new schema +class _AbstractBaseUser(AbstractBaseUser): + password: str = None # type: ignore[assignment] + last_login: datetime = None # type: ignore[assignment] + + class Meta(TypedModelMeta): + abstract = True + + +class User(_AbstractBaseUser, _User): """A proxy to Django's user class.""" _password: t.Optional[str] From 11cc4fc69ee42417f335547f058408eade952ef6 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 6 Nov 2024 19:14:34 +0000 Subject: [PATCH 32/56] fix: type hints --- codeforlife/models/abstract_base_session.py | 3 ++- codeforlife/models/abstract_base_user.py | 4 +++- codeforlife/models/base.py | 11 ++++------- codeforlife/models/base_session_store.py | 3 ++- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/codeforlife/models/abstract_base_session.py b/codeforlife/models/abstract_base_session.py index a7522753..85c42d94 100644 --- a/codeforlife/models/abstract_base_session.py +++ b/codeforlife/models/abstract_base_session.py @@ -14,6 +14,7 @@ from django.utils.translation import gettext_lazy as _ from .abstract_base_user import AbstractBaseUser +from .base import Model if t.TYPE_CHECKING: from django_stubs_ext.db.models import TypedModelMeta @@ -23,7 +24,7 @@ TypedModelMeta = object -class AbstractBaseSession(_AbstractBaseSession): +class AbstractBaseSession(Model, _AbstractBaseSession): """ Base session class to be inherited by all session classes. https://docs.djangoproject.com/en/3.2/topics/http/sessions/#example diff --git a/codeforlife/models/abstract_base_user.py b/codeforlife/models/abstract_base_user.py index 3323fe82..2173b1b3 100644 --- a/codeforlife/models/abstract_base_user.py +++ b/codeforlife/models/abstract_base_user.py @@ -8,6 +8,8 @@ from django.contrib.auth.models import AbstractBaseUser as _AbstractBaseUser from django.utils.translation import gettext_lazy as _ +from .base import Model + if t.TYPE_CHECKING: from django_stubs_ext.db.models import TypedModelMeta @@ -16,7 +18,7 @@ TypedModelMeta = object -class AbstractBaseUser(_AbstractBaseUser): +class AbstractBaseUser(Model, _AbstractBaseUser): """ Base user class to be inherited by all user classes. https://docs.djangoproject.com/en/3.2/topics/auth/customizing/#using-a-custom-user-model-when-starting-a-project diff --git a/codeforlife/models/base.py b/codeforlife/models/base.py index 462ba671..0f974c6e 100644 --- a/codeforlife/models/base.py +++ b/codeforlife/models/base.py @@ -7,6 +7,7 @@ import typing as t +from django.db.models import Manager from django.db.models import Model as _Model if t.TYPE_CHECKING: @@ -14,16 +15,12 @@ else: TypedModelMeta = object -Id = t.TypeVar("Id") +class Model(_Model): + """Base for all models.""" -class Model(_Model, t.Generic[Id]): - """A base class for all Django models.""" + objects: Manager[t.Self] - id: Id - pk: Id - - # pylint: disable-next=missing-class-docstring,too-few-public-methods class Meta(TypedModelMeta): abstract = True diff --git a/codeforlife/models/base_session_store.py b/codeforlife/models/base_session_store.py index facaee5b..ad8af54a 100644 --- a/codeforlife/models/base_session_store.py +++ b/codeforlife/models/base_session_store.py @@ -56,7 +56,8 @@ def associate_session_to_user( session: The anon session. user_id: The user to associate. """ - session.user = self.get_user_class().objects.get(id=user_id) + objects = self.get_user_class().objects # type: ignore[attr-defined] + session.user = objects.get(id=user_id) def create_model_instance(self, data): try: From 2664b2efd5566a4908b892406cd627ae9ece8ae3 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 6 Nov 2024 19:21:33 +0000 Subject: [PATCH 33/56] fix types --- codeforlife/models/abstract_base_session.py | 5 +++-- codeforlife/models/abstract_base_user.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/codeforlife/models/abstract_base_session.py b/codeforlife/models/abstract_base_session.py index 85c42d94..e4d8f4be 100644 --- a/codeforlife/models/abstract_base_session.py +++ b/codeforlife/models/abstract_base_session.py @@ -10,11 +10,11 @@ AbstractBaseSession as _AbstractBaseSession, ) from django.db import models +from django.db.models import Manager from django.utils import timezone from django.utils.translation import gettext_lazy as _ from .abstract_base_user import AbstractBaseUser -from .base import Model if t.TYPE_CHECKING: from django_stubs_ext.db.models import TypedModelMeta @@ -24,13 +24,14 @@ TypedModelMeta = object -class AbstractBaseSession(Model, _AbstractBaseSession): +class AbstractBaseSession(_AbstractBaseSession): """ Base session class to be inherited by all session classes. https://docs.djangoproject.com/en/3.2/topics/http/sessions/#example """ pk: str # type: ignore[assignment] + objects: Manager[t.Self] user_id: int user = models.OneToOneField( diff --git a/codeforlife/models/abstract_base_user.py b/codeforlife/models/abstract_base_user.py index 2173b1b3..66b654bc 100644 --- a/codeforlife/models/abstract_base_user.py +++ b/codeforlife/models/abstract_base_user.py @@ -6,10 +6,9 @@ import typing as t from django.contrib.auth.models import AbstractBaseUser as _AbstractBaseUser +from django.db.models import Manager from django.utils.translation import gettext_lazy as _ -from .base import Model - if t.TYPE_CHECKING: from django_stubs_ext.db.models import TypedModelMeta @@ -18,7 +17,7 @@ TypedModelMeta = object -class AbstractBaseUser(Model, _AbstractBaseUser): +class AbstractBaseUser(_AbstractBaseUser): """ Base user class to be inherited by all user classes. https://docs.djangoproject.com/en/3.2/topics/auth/customizing/#using-a-custom-user-model-when-starting-a-project @@ -27,6 +26,7 @@ class AbstractBaseUser(Model, _AbstractBaseUser): id: int pk: int session: "AbstractBaseSession" + objects: Manager[t.Self] # pylint: disable-next=missing-class-docstring,too-few-public-methods class Meta(TypedModelMeta): From a351a67f3a76f77b5e870b178c5f49e3bc9779b1 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 6 Nov 2024 19:22:52 +0000 Subject: [PATCH 34/56] fix types --- codeforlife/models/abstract_base_session.py | 2 -- codeforlife/models/abstract_base_user.py | 2 -- codeforlife/models/base_session_store.py | 4 +++- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/codeforlife/models/abstract_base_session.py b/codeforlife/models/abstract_base_session.py index e4d8f4be..a7522753 100644 --- a/codeforlife/models/abstract_base_session.py +++ b/codeforlife/models/abstract_base_session.py @@ -10,7 +10,6 @@ AbstractBaseSession as _AbstractBaseSession, ) from django.db import models -from django.db.models import Manager from django.utils import timezone from django.utils.translation import gettext_lazy as _ @@ -31,7 +30,6 @@ class AbstractBaseSession(_AbstractBaseSession): """ pk: str # type: ignore[assignment] - objects: Manager[t.Self] user_id: int user = models.OneToOneField( diff --git a/codeforlife/models/abstract_base_user.py b/codeforlife/models/abstract_base_user.py index 66b654bc..3323fe82 100644 --- a/codeforlife/models/abstract_base_user.py +++ b/codeforlife/models/abstract_base_user.py @@ -6,7 +6,6 @@ import typing as t from django.contrib.auth.models import AbstractBaseUser as _AbstractBaseUser -from django.db.models import Manager from django.utils.translation import gettext_lazy as _ if t.TYPE_CHECKING: @@ -26,7 +25,6 @@ class AbstractBaseUser(_AbstractBaseUser): id: int pk: int session: "AbstractBaseSession" - objects: Manager[t.Self] # pylint: disable-next=missing-class-docstring,too-few-public-methods class Meta(TypedModelMeta): diff --git a/codeforlife/models/base_session_store.py b/codeforlife/models/base_session_store.py index ad8af54a..af8520e5 100644 --- a/codeforlife/models/base_session_store.py +++ b/codeforlife/models/base_session_store.py @@ -72,7 +72,9 @@ def create_model_instance(self, data): session = model_class.objects.get(user_id=user_id) except model_class.DoesNotExist: session = model_class.objects.get(session_key=self.session_key) - self.associate_session_to_user(session, user_id) + self.associate_session_to_user( + t.cast(AnyAbstractBaseSession, session), user_id + ) session.session_data = self.encode(data) From 4c5857ac944414f600463e7984f50c78b2204794 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 10:05:36 +0000 Subject: [PATCH 35/56] disable too-many-ancestors --- codeforlife/user/models/user.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/codeforlife/user/models/user.py b/codeforlife/user/models/user.py index e36dc08d..9d9a0ad2 100644 --- a/codeforlife/user/models/user.py +++ b/codeforlife/user/models/user.py @@ -192,6 +192,7 @@ def filter_users(self, queryset: QuerySet[User]): return queryset.exclude(email__isnull=True).exclude(email="") +# pylint: disable-next=too-many-ancestors class ContactableUser(User): """A user that can be contacted.""" @@ -275,6 +276,7 @@ def get_queryset(self): return super().get_queryset().prefetch_related("new_teacher") +# pylint: disable-next=too-many-ancestors class TeacherUser(ContactableUser): """A user that is a teacher.""" @@ -488,6 +490,7 @@ def get_queryset(self): return super().get_queryset().prefetch_related("new_student") +# pylint: disable-next=too-many-ancestors class StudentUser(User): """A user that is a student.""" @@ -602,6 +605,7 @@ def create_user( # type: ignore[override] return user +# pylint: disable-next=too-many-ancestors class IndependentUser(ContactableUser): """A user that is an independent learner.""" From 04966b535103d4e5e866664e063bbf271a7259a2 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 10:49:51 +0000 Subject: [PATCH 36/56] fix linting --- codeforlife/models/base_session_store.py | 1 + 1 file changed, 1 insertion(+) diff --git a/codeforlife/models/base_session_store.py b/codeforlife/models/base_session_store.py index af8520e5..81300194 100644 --- a/codeforlife/models/base_session_store.py +++ b/codeforlife/models/base_session_store.py @@ -42,6 +42,7 @@ def get_model_class(cls) -> t.Type[AnyAbstractBaseSession]: @classmethod def get_user_class(cls) -> t.Type[AnyAbstractBaseUser]: + """Get the user class.""" # pylint: disable-next=no-member return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] 1 From ab235bf6c525dd7acbb4d4d94e0712962bf10886 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 11:04:08 +0000 Subject: [PATCH 37/56] abstract model view set test case and client --- codeforlife/tests/model_view_set.py | 118 +++++++++++++++------ codeforlife/tests/model_view_set_client.py | 43 ++++++-- 2 files changed, 119 insertions(+), 42 deletions(-) diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index 58b0e668..d52b344e 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -8,18 +8,20 @@ import typing as t from datetime import datetime +from django.contrib.auth import get_user_model from django.core.exceptions import ObjectDoesNotExist from django.db.models import Model from django.urls import reverse +from ..models import AbstractBaseUser from ..permissions import Permission from ..serializers import BaseSerializer from ..types import DataDict, JsonDict, KwArgs -from ..views import ModelViewSet -from .api import APITestCase -from .model_view_set_client import ModelViewSetClient +from ..views import BaseModelViewSet, ModelViewSet +from .api import APITestCase, BaseAPITestCase +from .model_view_set_client import BaseModelViewSetClient, ModelViewSetClient -# pylint: disable-next=duplicate-code +# pylint: disable=duplicate-code if t.TYPE_CHECKING: from ..user.models import User @@ -28,21 +30,32 @@ RequestUser = t.TypeVar("RequestUser") AnyModel = t.TypeVar("AnyModel", bound=Model) -# pylint: disable=no-member,too-many-arguments +AnyBaseModelViewSetClient = t.TypeVar( + "AnyBaseModelViewSetClient", bound=BaseModelViewSetClient +) +AnyBaseModelViewSet = t.TypeVar("AnyBaseModelViewSet", bound=BaseModelViewSet) +# pylint: enable=duplicate-code -# pylint: disable-next=too-many-ancestors -class ModelViewSetTestCase( - APITestCase[RequestUser], t.Generic[RequestUser, AnyModel] +class BaseModelViewSetTestCase( + BaseAPITestCase[AnyBaseModelViewSetClient], + t.Generic[AnyBaseModelViewSet, AnyBaseModelViewSetClient, AnyModel], ): """Base for all model view set test cases.""" basename: str - model_view_set_class: t.Type[ModelViewSet[RequestUser, AnyModel]] - client: ModelViewSetClient[RequestUser, AnyModel] - client_class: t.Type[ModelViewSetClient[RequestUser, AnyModel]] = ( - ModelViewSetClient - ) + model_view_set_class: t.Type[AnyBaseModelViewSet] + + REQUIRED_ATTRS: t.Set[str] = {"model_view_set_class", "basename"} + + @classmethod + def get_request_user_class(cls): + """Get the request's user class. + + Returns: + The request's user class. + """ + return t.cast(AbstractBaseUser, get_user_model()) @classmethod def get_model_class(cls) -> t.Type[AnyModel]: @@ -53,29 +66,16 @@ def get_model_class(cls) -> t.Type[AnyModel]: """ # pylint: disable-next=no-member return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 1 + 2 ] @classmethod def setUpClass(cls): - for attr in ["model_view_set_class", "basename"]: + for attr in cls.REQUIRED_ATTRS: assert hasattr(cls, attr), f'Attribute "{attr}" must be set.' return super().setUpClass() - def _get_client_class(self): - # TODO: unpack type args in index after moving to python 3.11 - # pylint: disable-next=too-few-public-methods - class _Client( - self.client_class[ # type: ignore[misc] - self.get_request_user_class(), - self.get_model_class(), - ] - ): - _test_case = self - - return _Client - def reverse_action( self, name: str, @@ -113,6 +113,7 @@ def reverse_action( # Assertion Helpers # -------------------------------------------------------------------------- + # pylint: disable-next=too-many-arguments def assert_serialized_model_equals_json_model( self, model: AnyModel, @@ -135,11 +136,8 @@ def assert_serialized_model_equals_json_model( """ # Get the logged-in user. try: - user = t.cast( - RequestUser, - self.get_request_user_class().objects.get( - session=self.client.session.session_key - ), + user = self.get_request_user_class().objects.get( + session=self.client.session.session_key ) except ObjectDoesNotExist: user = None # NOTE: no user has logged in. @@ -147,6 +145,7 @@ def assert_serialized_model_equals_json_model( # Create an instance of the model view set and serializer. model_view_set = self.model_view_set_class( action=action.replace("-", "_"), + # pylint: disable-next=no-member request=self.client.request_factory.generic( request_method, user=user ), @@ -239,6 +238,7 @@ def assert_get_serializer_context( serializer_context: The serializer's context. action: The model view set's action. """ + # pylint: disable-next=no-member kwargs.setdefault("request", self.client.request_factory.get()) kwargs.setdefault("format_kwarg", None) model_view_set = self.model_view_set_class( @@ -249,3 +249,55 @@ def assert_get_serializer_context( actual_serializer_context | serializer_context, actual_serializer_context, ) + + +# pylint: disable-next=too-many-ancestors +class ModelViewSetTestCase( + BaseModelViewSetTestCase[ + ModelViewSet[RequestUser, AnyModel], + ModelViewSetClient[RequestUser, AnyModel], + AnyModel, + ], + APITestCase[RequestUser], + t.Generic[RequestUser, AnyModel], +): + """Base for all model view set test cases.""" + + client_class = ModelViewSetClient + + @classmethod + def get_request_user_class(cls) -> t.Type[RequestUser]: + """Get the request's user class. + + Returns: + The request's user class. + """ + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 0 + ] + + @classmethod + def get_model_class(cls) -> t.Type[AnyModel]: + """Get the model view set's class. + + Returns: + The model view set's class. + """ + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 1 + ] + + def _get_client_class(self): + # TODO: unpack type args in index after moving to python 3.11 + # pylint: disable-next=too-few-public-methods + class _Client( + self.client_class[ # type: ignore[misc] + self.get_request_user_class(), + self.get_model_class(), + ] + ): + _test_case = self + + return _Client diff --git a/codeforlife/tests/model_view_set_client.py b/codeforlife/tests/model_view_set_client.py index a774bb0d..fa97765b 100644 --- a/codeforlife/tests/model_view_set_client.py +++ b/codeforlife/tests/model_view_set_client.py @@ -14,32 +14,39 @@ from rest_framework.response import Response from ..types import DataDict, JsonDict, KwArgs -from .api import APIClient +from .api import APIClient, BaseAPIClient +from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory -# pylint: disable-next=duplicate-code +# pylint: disable=duplicate-code if t.TYPE_CHECKING: from ..user.models import User - from .model_view_set import ModelViewSetTestCase + from .model_view_set import BaseModelViewSetTestCase, ModelViewSetTestCase RequestUser = t.TypeVar("RequestUser", bound=User) + AnyBaseModelViewSetTestCase = t.TypeVar( + "AnyBaseModelViewSetTestCase", bound=BaseModelViewSetTestCase + ) else: RequestUser = t.TypeVar("RequestUser") + AnyBaseModelViewSetTestCase = t.TypeVar("AnyBaseModelViewSetTestCase") AnyModel = t.TypeVar("AnyModel", bound=Model) -# pylint: disable=no-member,too-many-arguments +AnyBaseAPIRequestFactory = t.TypeVar( + "AnyBaseAPIRequestFactory", bound=BaseAPIRequestFactory +) +# pylint: enable=duplicate-code # pylint: disable-next=too-many-ancestors -class ModelViewSetClient( - APIClient[RequestUser], t.Generic[RequestUser, AnyModel] +class BaseModelViewSetClient( + BaseAPIClient[AnyBaseModelViewSetTestCase, AnyBaseAPIRequestFactory], + t.Generic[AnyBaseModelViewSetTestCase, AnyBaseAPIRequestFactory], ): """ An API client that helps make requests to a model view set and assert their responses. """ - _test_case: "ModelViewSetTestCase[RequestUser, AnyModel]" - @property def _model_class(self): """Shortcut to get model class.""" @@ -197,6 +204,7 @@ def retrieve( return response + # pylint: disable-next=too-many-arguments def list( self, models: t.Collection[AnyModel], @@ -258,6 +266,7 @@ def _make_assertions(response_json: JsonDict): # Partial Update (HTTP PATCH) # -------------------------------------------------------------------------- + # pylint: disable-next=too-many-arguments def _assert_update( self, model: AnyModel, @@ -271,6 +280,7 @@ def _assert_update( model, json_model, action, request_method, contains_subset=partial ) + # pylint: disable-next=too-many-arguments def partial_update( self, model: AnyModel, @@ -321,6 +331,7 @@ def partial_update( return response + # pylint: disable-next=too-many-arguments def bulk_partial_update( self, models: t.Union[t.List[AnyModel], QuerySet[AnyModel]], @@ -381,6 +392,7 @@ def _make_assertions(json_models: t.List[JsonDict]): # Update (HTTP PUT) # -------------------------------------------------------------------------- + # pylint: disable-next=too-many-arguments def update( self, model: AnyModel, @@ -431,6 +443,7 @@ def update( return response + # pylint: disable-next=too-many-arguments def bulk_update( self, models: t.Union[t.List[AnyModel], QuerySet[AnyModel]], @@ -605,4 +618,16 @@ def cron_job(self, action: str): return response -# pylint: enable=no-member +# pylint: disable-next=too-many-ancestors +class ModelViewSetClient( # type: ignore[misc] + BaseModelViewSetClient[ + "ModelViewSetTestCase[RequestUser, AnyModel]", + APIRequestFactory[RequestUser], + ], + APIClient[RequestUser], + t.Generic[RequestUser, AnyModel], +): + """ + An API client that helps make requests to a model view set and assert their + responses. + """ From 045deefd8dc806478ab6ff1a49754c9f0b7bb494 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 11:08:26 +0000 Subject: [PATCH 38/56] import base classes --- codeforlife/tests/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codeforlife/tests/__init__.py b/codeforlife/tests/__init__.py index a6d0ef65..ad077772 100644 --- a/codeforlife/tests/__init__.py +++ b/codeforlife/tests/__init__.py @@ -18,6 +18,6 @@ BaseModelSerializerTestCase, ModelSerializerTestCase, ) -from .model_view_set import ModelViewSetTestCase -from .model_view_set_client import ModelViewSetClient +from .model_view_set import BaseModelViewSetTestCase, ModelViewSetTestCase +from .model_view_set_client import BaseModelViewSetClient, ModelViewSetClient from .test import Client, TestCase From 3b0409e572e5eca8246fe6ffbf0f7bfcc855edd1 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 11:38:17 +0000 Subject: [PATCH 39/56] fix: session def --- codeforlife/models/abstract_base_session.py | 32 ++++++++++++++++----- codeforlife/user/models/session.py | 9 +----- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/codeforlife/models/abstract_base_session.py b/codeforlife/models/abstract_base_session.py index a7522753..d45dc543 100644 --- a/codeforlife/models/abstract_base_session.py +++ b/codeforlife/models/abstract_base_session.py @@ -5,7 +5,6 @@ import typing as t -from django.contrib.auth import get_user_model from django.contrib.sessions.base_session import ( AbstractBaseSession as _AbstractBaseSession, ) @@ -15,6 +14,7 @@ from .abstract_base_user import AbstractBaseUser +# pylint: disable=duplicate-code if t.TYPE_CHECKING: from django_stubs_ext.db.models import TypedModelMeta @@ -22,6 +22,9 @@ else: TypedModelMeta = object +AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) +# pylint: enable=duplicate-code + class AbstractBaseSession(_AbstractBaseSession): """ @@ -32,12 +35,6 @@ class AbstractBaseSession(_AbstractBaseSession): pk: str # type: ignore[assignment] user_id: int - user = models.OneToOneField( - t.cast(t.Type[AbstractBaseUser], get_user_model()), - null=True, - blank=True, - on_delete=models.CASCADE, - ) # pylint: disable-next=missing-class-docstring,too-few-public-methods class Meta(TypedModelMeta): @@ -58,3 +55,24 @@ def store(self): @classmethod def get_session_store_class(cls) -> t.Type["BaseSessionStore"]: raise NotImplementedError + + @staticmethod + def init_user_field(user_class: t.Type[AnyAbstractBaseUser]): + """Initializes the user field that relates a session to a user. + + Example: + class Session(AbstractBaseSession): + user = AbstractBaseSession.init_user_field(User) + + Args: + user_class: The user model to associate sessions to. + + Returns: + A one-to-one field that relates to the provided user model. + """ + return models.OneToOneField( + user_class, + null=True, + blank=True, + on_delete=models.CASCADE, + ) diff --git a/codeforlife/user/models/session.py b/codeforlife/user/models/session.py index a6df0f95..a2819de2 100644 --- a/codeforlife/user/models/session.py +++ b/codeforlife/user/models/session.py @@ -5,7 +5,6 @@ import typing as t -from django.db import models from django.db.models.query import QuerySet from ...models import AbstractBaseSession, BaseSessionStore @@ -23,13 +22,7 @@ class Session(AbstractBaseSession): auth_factors: QuerySet["SessionAuthFactor"] - # TODO: remove in new schema - user = models.OneToOneField( # type: ignore[assignment] - User, - null=True, - blank=True, - on_delete=models.CASCADE, - ) + user = AbstractBaseSession.init_user_field(User) @classmethod def get_session_store_class(cls): From 40e7be7290328647464f492d26a08fa7a39dc3e6 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 11:45:48 +0000 Subject: [PATCH 40/56] mypy ignore --- codeforlife/models/base_session_store.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/codeforlife/models/base_session_store.py b/codeforlife/models/base_session_store.py index 81300194..332a5de2 100644 --- a/codeforlife/models/base_session_store.py +++ b/codeforlife/models/base_session_store.py @@ -58,7 +58,7 @@ def associate_session_to_user( user_id: The user to associate. """ objects = self.get_user_class().objects # type: ignore[attr-defined] - session.user = objects.get(id=user_id) + session.user = objects.get(id=user_id) # type: ignore[attr-defined] def create_model_instance(self, data): try: @@ -70,7 +70,9 @@ def create_model_instance(self, data): model_class = self.get_model_class() try: - session = model_class.objects.get(user_id=user_id) + session = model_class.objects.get( + user_id=user_id, # type: ignore[misc] + ) except model_class.DoesNotExist: session = model_class.objects.get(session_key=self.session_key) self.associate_session_to_user( From b633a24ec881a51988de5e4bac3b37df320f8082 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 12:20:21 +0000 Subject: [PATCH 41/56] remove id field --- codeforlife/models/abstract_base_user.py | 1 - 1 file changed, 1 deletion(-) diff --git a/codeforlife/models/abstract_base_user.py b/codeforlife/models/abstract_base_user.py index 3323fe82..9d9ff4b4 100644 --- a/codeforlife/models/abstract_base_user.py +++ b/codeforlife/models/abstract_base_user.py @@ -22,7 +22,6 @@ class AbstractBaseUser(_AbstractBaseUser): https://docs.djangoproject.com/en/3.2/topics/auth/customizing/#using-a-custom-user-model-when-starting-a-project """ - id: int pk: int session: "AbstractBaseSession" From 62266099e0375e2a2d1989bf9f72d0a76fc5bc58 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 13:09:10 +0000 Subject: [PATCH 42/56] abstract is authenticated --- codeforlife/models/abstract_base_user.py | 35 ++++++++++++++++++++++++ codeforlife/user/models/user.py | 15 ++++------ 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/codeforlife/models/abstract_base_user.py b/codeforlife/models/abstract_base_user.py index 9d9ff4b4..70aef6a2 100644 --- a/codeforlife/models/abstract_base_user.py +++ b/codeforlife/models/abstract_base_user.py @@ -3,8 +3,12 @@ Created on 06/11/2024 at 16:38:15(+00:00). """ +import sys import typing as t +from functools import cached_property +from django.apps import apps +from django.conf import settings from django.contrib.auth.models import AbstractBaseUser as _AbstractBaseUser from django.utils.translation import gettext_lazy as _ @@ -30,3 +34,34 @@ class Meta(TypedModelMeta): abstract = True verbose_name = _("user") verbose_name_plural = _("users") + + @cached_property + def _session_class(self): + return t.cast( + t.Type["AbstractBaseSession"], + apps.get_model( + app_label=( + t.cast(str, settings.SESSION_ENGINE) + .lower() + .removesuffix(".models.session") + .split(".")[-1] + ), + model_name="session", + ), + ) + + @property + def is_authenticated(self): + """A flag designating if this contributor has authenticated.""" + # Avoid initial migration error where session table is not created yet + if ( + sys.argv + and "manage.py" in sys.argv[0] + and "runserver" not in sys.argv + ): + return True + + try: + return self.is_active and not self.session.is_expired + except self._session_class.DoesNotExist: + return False diff --git a/codeforlife/user/models/user.py b/codeforlife/user/models/user.py index 9d9a0ad2..fe75ee25 100644 --- a/codeforlife/user/models/user.py +++ b/codeforlife/user/models/user.py @@ -62,16 +62,11 @@ class Meta(TypedModelMeta): @property def is_authenticated(self): - """ - Check if the user has any pending auth factors. - """ - # pylint: disable-next=import-outside-toplevel - from .session import Session - - try: - return self.is_active and not self.session.auth_factors.exists() - except Session.DoesNotExist: - return False + return ( + not self.session.auth_factors.exists() + if super().is_authenticated + else False + ) @property def student(self) -> t.Optional["Student"]: From 691d3721e8e278312e82eebcfe19f31a4ebe6c35 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 13:21:39 +0000 Subject: [PATCH 43/56] fix: comment out check --- codeforlife/models/abstract_base_user.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/codeforlife/models/abstract_base_user.py b/codeforlife/models/abstract_base_user.py index 70aef6a2..8c74a267 100644 --- a/codeforlife/models/abstract_base_user.py +++ b/codeforlife/models/abstract_base_user.py @@ -3,7 +3,7 @@ Created on 06/11/2024 at 16:38:15(+00:00). """ -import sys +# import sys import typing as t from functools import cached_property @@ -53,13 +53,14 @@ def _session_class(self): @property def is_authenticated(self): """A flag designating if this contributor has authenticated.""" + # TODO: delete if not needed. # Avoid initial migration error where session table is not created yet - if ( - sys.argv - and "manage.py" in sys.argv[0] - and "runserver" not in sys.argv - ): - return True + # if ( + # sys.argv + # and "manage.py" in sys.argv[0] + # and "runserver" not in sys.argv + # ): + # return True try: return self.is_active and not self.session.is_expired From 58a5d585e93f317d465f1037586e8b43929c733d Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 13:26:10 +0000 Subject: [PATCH 44/56] delete unnecessary code --- codeforlife/models/abstract_base_user.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/codeforlife/models/abstract_base_user.py b/codeforlife/models/abstract_base_user.py index 8c74a267..5b2305f4 100644 --- a/codeforlife/models/abstract_base_user.py +++ b/codeforlife/models/abstract_base_user.py @@ -3,7 +3,6 @@ Created on 06/11/2024 at 16:38:15(+00:00). """ -# import sys import typing as t from functools import cached_property @@ -53,15 +52,6 @@ def _session_class(self): @property def is_authenticated(self): """A flag designating if this contributor has authenticated.""" - # TODO: delete if not needed. - # Avoid initial migration error where session table is not created yet - # if ( - # sys.argv - # and "manage.py" in sys.argv[0] - # and "runserver" not in sys.argv - # ): - # return True - try: return self.is_active and not self.session.is_expired except self._session_class.DoesNotExist: From 687d58797d7e46c6fb039c02172b90fada557771 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 13:54:34 +0000 Subject: [PATCH 45/56] fix pre setup --- codeforlife/tests/api.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/codeforlife/tests/api.py b/codeforlife/tests/api.py index 2b81f0d0..02420746 100644 --- a/codeforlife/tests/api.py +++ b/codeforlife/tests/api.py @@ -26,6 +26,11 @@ class BaseAPITestCase(TestCase, t.Generic[AnyBaseAPIClient]): client: AnyBaseAPIClient client_class: t.Type[AnyBaseAPIClient] + def _pre_setup(self): + # pylint: disable-next=protected-access + self.client_class._test_case = self + super()._pre_setup() # type: ignore[misc] + class APITestCase( BaseAPITestCase[APIClient[RequestUser]], @@ -54,10 +59,10 @@ class _Client( self.get_request_user_class() ] ): - _test_case = self + pass return _Client def _pre_setup(self): self.client_class = self._get_client_class() - super()._pre_setup() # type: ignore[misc] + super()._pre_setup() From f0928328c36454978fb612400128321f8394481e Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 14:01:22 +0000 Subject: [PATCH 46/56] disable no-member --- codeforlife/tests/model_view_set_client.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/codeforlife/tests/model_view_set_client.py b/codeforlife/tests/model_view_set_client.py index fa97765b..d430f46b 100644 --- a/codeforlife/tests/model_view_set_client.py +++ b/codeforlife/tests/model_view_set_client.py @@ -36,6 +36,8 @@ ) # pylint: enable=duplicate-code +# pylint: disable=no-member + # pylint: disable-next=too-many-ancestors class BaseModelViewSetClient( @@ -618,6 +620,9 @@ def cron_job(self, action: str): return response +# pylint: enable=no-member + + # pylint: disable-next=too-many-ancestors class ModelViewSetClient( # type: ignore[misc] BaseModelViewSetClient[ From b7e603eea5eb19031171eaff79cf0c264533f2f7 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 14:22:49 +0000 Subject: [PATCH 47/56] model serializer type arg --- codeforlife/views/model.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/codeforlife/views/model.py b/codeforlife/views/model.py index 93ec7f5c..3f6dd195 100644 --- a/codeforlife/views/model.py +++ b/codeforlife/views/model.py @@ -31,6 +31,9 @@ from ..user.models import User RequestUser = t.TypeVar("RequestUser", bound=User) + AnyBaseModelSerializer = t.TypeVar( + "AnyBaseModelSerializer", bound=BaseModelSerializer + ) # NOTE: This raises an error during runtime. # pylint: disable-next=too-few-public-methods @@ -39,6 +42,7 @@ class _ModelViewSet(DrfModelViewSet[AnyModel], t.Generic[AnyModel]): else: RequestUser = t.TypeVar("RequestUser") + AnyBaseModelSerializer = t.TypeVar("AnyBaseModelSerializer") # pylint: disable-next=too-many-ancestors class _ModelViewSet(DrfModelViewSet, t.Generic[AnyModel]): @@ -54,13 +58,11 @@ class _ModelViewSet(DrfModelViewSet, t.Generic[AnyModel]): class BaseModelViewSet( BaseAPIView[AnyBaseRequest], _ModelViewSet[AnyModel], - t.Generic[AnyBaseRequest, AnyModel], + t.Generic[AnyBaseRequest, AnyBaseModelSerializer, AnyModel], ): """Base model view set for all model view sets.""" - serializer_class: t.Optional[ - t.Type["BaseModelSerializer[AnyBaseRequest, AnyModel]"] - ] + serializer_class: t.Optional[t.Type[AnyBaseModelSerializer]] @classmethod def get_model_class(cls) -> t.Type[AnyModel]: @@ -160,16 +162,16 @@ def partial_update( # type: ignore[override] # pragma: no cover # pylint: disable-next=too-many-ancestors class ModelViewSet( - BaseModelViewSet[Request[RequestUser], AnyModel], + BaseModelViewSet[ + Request[RequestUser], + "ModelSerializer[RequestUser, AnyModel]", + AnyModel, + ], APIView[RequestUser], t.Generic[RequestUser, AnyModel], ): """Base model view set for all model view sets.""" - serializer_class: t.Optional[ - t.Type["ModelSerializer[RequestUser, AnyModel]"] - ] - def get_bulk_queryset(self, lookup_values: t.Collection): """Get the queryset for a bulk action. From 5848e30b152a991d8650b63a13e9f79d211f7878 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 14:32:28 +0000 Subject: [PATCH 48/56] AnyBaseModelViewSet --- codeforlife/serializers/model.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/codeforlife/serializers/model.py b/codeforlife/serializers/model.py index 7b5601c2..0125ffb8 100644 --- a/codeforlife/serializers/model.py +++ b/codeforlife/serializers/model.py @@ -20,8 +20,12 @@ from ..views import BaseModelViewSet, ModelViewSet RequestUser = t.TypeVar("RequestUser", bound=User) + AnyBaseModelViewSet = t.TypeVar( + "AnyBaseModelViewSet", bound=BaseModelViewSet + ) else: RequestUser = t.TypeVar("RequestUser") + AnyBaseModelViewSet = t.TypeVar("AnyBaseModelViewSet") AnyModel = t.TypeVar("AnyModel", bound=Model) AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest) @@ -31,12 +35,12 @@ class BaseModelSerializer( BaseSerializer[AnyBaseRequest], _ModelSerializer[AnyModel], - t.Generic[AnyBaseRequest, AnyModel], + t.Generic[AnyBaseRequest, AnyBaseModelViewSet, AnyModel], ): """Base model serializer for all model serializers.""" instance: t.Optional[AnyModel] - view: "BaseModelViewSet[AnyBaseRequest, AnyModel]" + view: AnyBaseModelViewSet @property def non_none_instance(self): @@ -60,9 +64,11 @@ def to_representation(self, instance: AnyModel) -> DataDict: class ModelSerializer( - BaseModelSerializer[Request[RequestUser], AnyModel], + BaseModelSerializer[ + Request[RequestUser], + "ModelViewSet[RequestUser, AnyModel]", + AnyModel, + ], t.Generic[RequestUser, AnyModel], ): """Base model serializer for all model serializers.""" - - view: "ModelViewSet[RequestUser, AnyModel]" # type: ignore[assignment] From 68df52259a4d9a8aeeeccc088947b27d0667e04b Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 14:37:30 +0000 Subject: [PATCH 49/56] AnyBaseModelViewSet --- codeforlife/serializers/model_list.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/codeforlife/serializers/model_list.py b/codeforlife/serializers/model_list.py index 18249adb..99be97ac 100644 --- a/codeforlife/serializers/model_list.py +++ b/codeforlife/serializers/model_list.py @@ -21,8 +21,12 @@ from ..views import BaseModelViewSet, ModelViewSet RequestUser = t.TypeVar("RequestUser", bound=User) + AnyBaseModelViewSet = t.TypeVar( + "AnyBaseModelViewSet", bound=BaseModelViewSet + ) else: RequestUser = t.TypeVar("RequestUser") + AnyBaseModelViewSet = t.TypeVar("AnyBaseModelViewSet") AnyModel = t.TypeVar("AnyModel", bound=Model) AnyBaseRequest = t.TypeVar("AnyBaseRequest", bound=BaseRequest) @@ -36,7 +40,7 @@ class BaseModelListSerializer( BaseSerializer[AnyBaseRequest], _ListSerializer[t.List[AnyModel]], - t.Generic[AnyBaseRequest, AnyModel], + t.Generic[AnyBaseRequest, AnyBaseModelViewSet, AnyModel], ): """Base model list serializer for all model list serializers. @@ -57,7 +61,7 @@ class Meta: instance: t.Optional[t.List[AnyModel]] batch_size: t.Optional[int] = None - view: "BaseModelViewSet[AnyBaseRequest, AnyModel]" + view: AnyBaseModelViewSet @property def non_none_instance(self): @@ -185,7 +189,11 @@ def to_representation(self, instance: t.List[AnyModel]) -> t.List[DataDict]: class ModelListSerializer( - BaseModelListSerializer[Request[RequestUser], AnyModel], + BaseModelListSerializer[ + Request[RequestUser], + "ModelViewSet[RequestUser, AnyModel]", + AnyModel, + ], t.Generic[RequestUser, AnyModel], ): """Base model list serializer for all model list serializers. @@ -204,5 +212,3 @@ class Meta: model = User list_serializer_class = UserListSerializer """ - - view: "ModelViewSet[RequestUser, AnyModel]" # type: ignore[assignment] From 623b1a2a334a5f60cd0f516fe7633fc2a28d3345 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 14:42:40 +0000 Subject: [PATCH 50/56] fix type hints --- codeforlife/tests/model_serializer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codeforlife/tests/model_serializer.py b/codeforlife/tests/model_serializer.py index 984652a5..84f52052 100644 --- a/codeforlife/tests/model_serializer.py +++ b/codeforlife/tests/model_serializer.py @@ -141,7 +141,7 @@ def _assert_many( new_data: t.Optional[t.List[DataDict]], non_model_fields: t.Optional[NonModelFields], get_models: t.Callable[ - [BaseModelListSerializer[t.Any, AnyModel], t.List[DataDict]], + [BaseModelListSerializer[t.Any, t.Any, AnyModel], t.List[DataDict]], t.List[AnyModel], ], *args, @@ -154,7 +154,7 @@ def _assert_many( assert len(new_data) == len(validated_data) kwargs.pop("many", None) # many must be True - serializer: BaseModelListSerializer[t.Any, AnyModel] = ( + serializer: BaseModelListSerializer[t.Any, t.Any, AnyModel] = ( self._init_model_serializer(*args, **kwargs, many=True) ) From 9fc68d3c8a699f7edae8ae3e88b480e54ac5787f Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 15:21:59 +0000 Subject: [PATCH 51/56] base login view and form --- codeforlife/forms.py | 83 +++++++++++++++++++++++++ codeforlife/views/__init__.py | 1 + codeforlife/views/base_login.py | 106 ++++++++++++++++++++++++++++++++ 3 files changed, 190 insertions(+) create mode 100644 codeforlife/forms.py create mode 100644 codeforlife/views/base_login.py diff --git a/codeforlife/forms.py b/codeforlife/forms.py new file mode 100644 index 00000000..12f3101b --- /dev/null +++ b/codeforlife/forms.py @@ -0,0 +1,83 @@ +""" +© Ocado Group +Created on 07/11/2024 at 15:08:33(+00:00). +""" + +import typing as t + +from django import forms +from django.contrib.auth import authenticate +from django.core.exceptions import ValidationError +from django.core.handlers.wsgi import WSGIRequest + +from .models import AbstractBaseUser + +AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) + + +class BaseLoginForm(forms.Form, t.Generic[AnyAbstractBaseUser]): + """Base login form that all other login forms must inherit.""" + + user: AnyAbstractBaseUser + + @classmethod + def get_user_class(cls) -> t.Type[AnyAbstractBaseUser]: + """Get the user class.""" + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 0 + ] + + def __init__(self, request: WSGIRequest, *args, **kwargs): + self.request = request + super().__init__(*args, **kwargs) + + def clean(self): + """Authenticates a user. + + Raises: + ValidationError: If there are form errors. + ValidationError: If the user's credentials were incorrect. + ValidationError: If the user's account is deactivated. + + Returns: + The cleaned form data. + """ + + if self.errors: + raise ValidationError( + "Found form errors. Skipping authentication.", + code="form_errors", + ) + + user = authenticate( + self.request, + **{key: self.cleaned_data[key] for key in self.fields.keys()} + ) + if user is None: + raise ValidationError( + self.get_invalid_login_error_message(), + code="invalid_login", + ) + if not isinstance(user, self.get_user_class()): + raise ValidationError( + "Incorrect user class.", + code="incorrect_user_class", + ) + if not user.is_active: + raise ValidationError( + "User is not active", + code="user_not_active", + ) + + self.user = user + + return self.cleaned_data + + def get_invalid_login_error_message(self) -> str: + """Returns the error message if the user failed to login. + + Raises: + NotImplementedError: If message is not set. + """ + raise NotImplementedError() diff --git a/codeforlife/views/__init__.py b/codeforlife/views/__init__.py index 33ecf266..3abde6bf 100644 --- a/codeforlife/views/__init__.py +++ b/codeforlife/views/__init__.py @@ -4,6 +4,7 @@ """ from .api import APIView, BaseAPIView +from .base_login import BaseLoginView from .common import CsrfCookieView, LogoutView from .decorators import action, cron_job from .model import BaseModelViewSet, ModelViewSet diff --git a/codeforlife/views/base_login.py b/codeforlife/views/base_login.py new file mode 100644 index 00000000..76eb1a49 --- /dev/null +++ b/codeforlife/views/base_login.py @@ -0,0 +1,106 @@ +""" +© Ocado Group +Created on 07/11/2024 at 14:58:38(+00:00). +""" + +import json +import typing as t +from urllib.parse import quote_plus + +from django.conf import settings +from django.contrib.auth import login +from django.contrib.auth.views import LoginView +from django.http import JsonResponse +from rest_framework import status + +from ..forms import BaseLoginForm +from ..models import AbstractBaseUser +from ..request import BaseHttpRequest +from ..types import JsonDict + +AnyBaseHttpRequest = t.TypeVar("AnyBaseHttpRequest", bound=BaseHttpRequest) +AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) + + +class BaseLoginView( + LoginView, + t.Generic[AnyBaseHttpRequest, AnyAbstractBaseUser], +): + """ + Extends Django's login view to allow a user to log in using one of the + approved forms. + + WARNING: It's critical that to inherit Django's login view as it implements + industry standard security measures that a login view should have. + """ + + request: AnyBaseHttpRequest + + def get_form_kwargs(self): + form_kwargs = super().get_form_kwargs() + form_kwargs["data"] = json.loads(self.request.body) + + return form_kwargs + + def get_session_metadata(self, user: AnyAbstractBaseUser) -> JsonDict: + """Get the session's metadata. + + Args: + user: The user the session is for. + + Raises: + NotImplementedError: If this method is not implemented. + + Returns: + A JSON-serializable dict containing the session's metadata. + """ + raise NotImplementedError + + def form_valid( + self, form: BaseLoginForm[AnyAbstractBaseUser] # type: ignore + ): + user = form.user + + # Clear expired sessions. + self.request.session.clear_expired(user_id=user.pk) + + # Create session (without data). + login(self.request, user) + + # Save session (with data). + self.request.session.save() + + # Get session metadata. + session_metadata = self.get_session_metadata(user) + + # Return session metadata in response and a non-HTTP-only cookie. + response = JsonResponse(session_metadata) + response.set_cookie( + key=settings.SESSION_METADATA_COOKIE_NAME, + value=quote_plus( + json.dumps( + session_metadata, + separators=(",", ":"), + indent=None, + ) + ), + max_age=( + None + if settings.SESSION_EXPIRE_AT_BROWSER_CLOSE + else settings.SESSION_COOKIE_AGE + ), + secure=settings.SESSION_COOKIE_SECURE, + samesite=t.cast( + t.Optional[t.Literal["Lax", "Strict", "None", False]], + settings.SESSION_COOKIE_SAMESITE, + ), + domain=settings.SESSION_COOKIE_DOMAIN, + httponly=False, + ) + + return response + + def form_invalid( + self, form: BaseLoginForm[AnyAbstractBaseUser] # type: ignore + ): + return JsonResponse(form.errors, status=status.HTTP_400_BAD_REQUEST) From 882de7316c5c5054b7ca02d8351704a8b298d398 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 15:42:40 +0000 Subject: [PATCH 52/56] fix: import --- codeforlife/tests/api_client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/codeforlife/tests/api_client.py b/codeforlife/tests/api_client.py index b2113504..4835c7c9 100644 --- a/codeforlife/tests/api_client.py +++ b/codeforlife/tests/api_client.py @@ -373,6 +373,9 @@ def login(self, **credentials): Returns: The user. """ + # pylint: disable-next=import-outside-toplevel + from ..user.models import User + return self._login_user_type(User, **credentials) def login_teacher(self, email: str, password: str = "password"): From 4675ee289a551205abbb0cf8d94e1a1eb202b1f2 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 17:27:05 +0000 Subject: [PATCH 53/56] get arg helper --- codeforlife/forms.py | 6 ++---- codeforlife/models/base_session_store.py | 12 ++++-------- codeforlife/serializers/model_list.py | 7 ++----- codeforlife/tests/api.py | 6 ++---- codeforlife/tests/api_client.py | 7 ++----- codeforlife/tests/api_request_factory.py | 6 ++---- codeforlife/tests/model.py | 6 ++---- codeforlife/tests/model_serializer.py | 12 +++--------- codeforlife/tests/model_view_set.py | 16 ++++------------ codeforlife/types.py | 16 ++++++++++++++++ codeforlife/views/api.py | 6 ++---- codeforlife/views/model.py | 7 ++----- 12 files changed, 43 insertions(+), 64 deletions(-) diff --git a/codeforlife/forms.py b/codeforlife/forms.py index 12f3101b..b93046a7 100644 --- a/codeforlife/forms.py +++ b/codeforlife/forms.py @@ -11,6 +11,7 @@ from django.core.handlers.wsgi import WSGIRequest from .models import AbstractBaseUser +from .types import get_arg AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) @@ -23,10 +24,7 @@ class BaseLoginForm(forms.Form, t.Generic[AnyAbstractBaseUser]): @classmethod def get_user_class(cls) -> t.Type[AnyAbstractBaseUser]: """Get the user class.""" - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) def __init__(self, request: WSGIRequest, *args, **kwargs): self.request = request diff --git a/codeforlife/models/base_session_store.py b/codeforlife/models/base_session_store.py index 332a5de2..b6e5e648 100644 --- a/codeforlife/models/base_session_store.py +++ b/codeforlife/models/base_session_store.py @@ -9,6 +9,8 @@ from django.contrib.sessions.backends.db import SessionStore from django.utils import timezone +from ..types import get_arg + if t.TYPE_CHECKING: from .abstract_base_session import AbstractBaseSession from .abstract_base_user import AbstractBaseUser @@ -35,18 +37,12 @@ class BaseSessionStore( @classmethod def get_model_class(cls) -> t.Type[AnyAbstractBaseSession]: - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) @classmethod def get_user_class(cls) -> t.Type[AnyAbstractBaseUser]: """Get the user class.""" - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 1 - ] + return get_arg(cls, 1) def associate_session_to_user( self, session: AnyAbstractBaseSession, user_id: int diff --git a/codeforlife/serializers/model_list.py b/codeforlife/serializers/model_list.py index 99be97ac..ecff4273 100644 --- a/codeforlife/serializers/model_list.py +++ b/codeforlife/serializers/model_list.py @@ -12,7 +12,7 @@ from rest_framework.serializers import ValidationError as _ValidationError from ..request import BaseRequest, Request -from ..types import DataDict, OrderedDataDict +from ..types import DataDict, OrderedDataDict, get_arg from .base import BaseSerializer # pylint: disable=duplicate-code @@ -75,10 +75,7 @@ def get_model_class(cls) -> t.Type[AnyModel]: Returns: The model view set's class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) def __init__(self, *args, **kwargs): instance = args[0] if args else kwargs.pop("instance", None) diff --git a/codeforlife/tests/api.py b/codeforlife/tests/api.py index 02420746..ac54e987 100644 --- a/codeforlife/tests/api.py +++ b/codeforlife/tests/api.py @@ -5,6 +5,7 @@ import typing as t +from ..types import get_arg from .api_client import APIClient, BaseAPIClient from .test import TestCase @@ -47,10 +48,7 @@ def get_request_user_class(cls) -> t.Type[RequestUser]: Returns: The request's user class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) def _get_client_class(self): # pylint: disable-next=too-few-public-methods diff --git a/codeforlife/tests/api_client.py b/codeforlife/tests/api_client.py index 4835c7c9..3dfa4450 100644 --- a/codeforlife/tests/api_client.py +++ b/codeforlife/tests/api_client.py @@ -12,7 +12,7 @@ from rest_framework.response import Response from rest_framework.test import APIClient as _APIClient -from ..types import DataDict, JsonDict +from ..types import DataDict, JsonDict, get_arg from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory # pylint: disable=duplicate-code @@ -329,10 +329,7 @@ def get_request_user_class(cls) -> t.Type[RequestUser]: Returns: The request's user class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) # -------------------------------------------------------------------------- # Login Helpers diff --git a/codeforlife/tests/api_request_factory.py b/codeforlife/tests/api_request_factory.py index 0758e1a5..fbce7e56 100644 --- a/codeforlife/tests/api_request_factory.py +++ b/codeforlife/tests/api_request_factory.py @@ -16,6 +16,7 @@ from rest_framework.test import APIRequestFactory as _APIRequestFactory from ..request import BaseRequest, Request +from ..types import get_arg # pylint: disable=duplicate-code if t.TYPE_CHECKING: @@ -247,10 +248,7 @@ def get_user_class(cls) -> t.Type[AnyUser]: Returns: The user class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) def _init_request(self, wsgi_request): return Request[AnyUser]( diff --git a/codeforlife/tests/model.py b/codeforlife/tests/model.py index e04eaa00..a61edb1c 100644 --- a/codeforlife/tests/model.py +++ b/codeforlife/tests/model.py @@ -10,6 +10,7 @@ from django.db.models import Model from django.db.utils import IntegrityError +from ..types import get_arg from .test import TestCase AnyModel = t.TypeVar("AnyModel", bound=Model) @@ -25,10 +26,7 @@ def get_model_class(cls) -> t.Type[AnyModel]: Returns: The model's class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) def assert_raises_integrity_error(self, *args, **kwargs): """Assert the code block raises an integrity error. diff --git a/codeforlife/tests/model_serializer.py b/codeforlife/tests/model_serializer.py index 84f52052..1f41dec1 100644 --- a/codeforlife/tests/model_serializer.py +++ b/codeforlife/tests/model_serializer.py @@ -18,7 +18,7 @@ BaseModelSerializer, ModelSerializer, ) -from ..types import DataDict +from ..types import DataDict, get_arg from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory from .test import TestCase @@ -397,10 +397,7 @@ def get_request_user_class(cls) -> t.Type[AnyModel]: Returns: The model view set's class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) @classmethod def get_model_class(cls) -> t.Type[AnyModel]: @@ -409,10 +406,7 @@ def get_model_class(cls) -> t.Type[AnyModel]: Returns: The model view set's class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 1 - ] + return get_arg(cls, 1) @classmethod def _initialize_request_factory(cls, **kwargs): diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index d52b344e..2304a926 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -16,7 +16,7 @@ from ..models import AbstractBaseUser from ..permissions import Permission from ..serializers import BaseSerializer -from ..types import DataDict, JsonDict, KwArgs +from ..types import DataDict, JsonDict, KwArgs, get_arg from ..views import BaseModelViewSet, ModelViewSet from .api import APITestCase, BaseAPITestCase from .model_view_set_client import BaseModelViewSetClient, ModelViewSetClient @@ -65,9 +65,7 @@ def get_model_class(cls) -> t.Type[AnyModel]: The model view set's class. """ # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 2 - ] + return get_arg(cls, 2) @classmethod def setUpClass(cls): @@ -272,10 +270,7 @@ def get_request_user_class(cls) -> t.Type[RequestUser]: Returns: The request's user class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) @classmethod def get_model_class(cls) -> t.Type[AnyModel]: @@ -284,10 +279,7 @@ def get_model_class(cls) -> t.Type[AnyModel]: Returns: The model view set's class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 1 - ] + return get_arg(cls, 1) def _get_client_class(self): # TODO: unpack type args in index after moving to python 3.11 diff --git a/codeforlife/types.py b/codeforlife/types.py index 2d151263..a2c85b36 100644 --- a/codeforlife/types.py +++ b/codeforlife/types.py @@ -7,6 +7,8 @@ import typing as t +T = t.TypeVar("T") + Args = t.Tuple[t.Any, ...] KwArgs = t.Dict[str, t.Any] @@ -16,3 +18,17 @@ DataDict = t.Dict[str, t.Any] OrderedDataDict = t.OrderedDict[str, t.Any] + + +def get_arg(cls: t.Type[t.Any], index: int, orig_base: int = 0): + """Get a type arg from a class. + + Args: + cls: The class to get the type arg from. + index: The index of the type arg to get. + orig_base: The base class to get the type arg from. + + Returns: + The type arg from the class. + """ + return t.get_args(cls.__orig_bases__[orig_base])[index] diff --git a/codeforlife/views/api.py b/codeforlife/views/api.py index ff143590..1b95ead1 100644 --- a/codeforlife/views/api.py +++ b/codeforlife/views/api.py @@ -9,6 +9,7 @@ from rest_framework.views import APIView as _APIView from ..request import BaseRequest, Request +from ..types import get_arg # pylint: disable=duplicate-code if t.TYPE_CHECKING: @@ -55,10 +56,7 @@ def get_request_user_class(cls) -> t.Type[RequestUser]: Returns: The request's user class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) def _initialize_request(self, request, **kwargs): kwargs["user_class"] = self.get_request_user_class() diff --git a/codeforlife/views/model.py b/codeforlife/views/model.py index 3f6dd195..1ac326e5 100644 --- a/codeforlife/views/model.py +++ b/codeforlife/views/model.py @@ -15,7 +15,7 @@ from ..permissions import Permission from ..request import BaseRequest, Request -from ..types import KwArgs +from ..types import KwArgs, get_arg from .api import APIView, BaseAPIView from .decorators import action @@ -71,10 +71,7 @@ def get_model_class(cls) -> t.Type[AnyModel]: Returns: The model view set's class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) @cached_property def lookup_field_name(self): From 655604e6c749ace4866e12abd76d97962f017e66 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 17:29:36 +0000 Subject: [PATCH 54/56] delete unused var --- codeforlife/types.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/codeforlife/types.py b/codeforlife/types.py index a2c85b36..68524a19 100644 --- a/codeforlife/types.py +++ b/codeforlife/types.py @@ -7,8 +7,6 @@ import typing as t -T = t.TypeVar("T") - Args = t.Tuple[t.Any, ...] KwArgs = t.Dict[str, t.Any] From 3a7c0faec937c404e8ff8f608aa547abf27467ce Mon Sep 17 00:00:00 2001 From: SKairinos Date: Mon, 11 Nov 2024 13:12:45 +0000 Subject: [PATCH 55/56] migrate on app startup --- codeforlife/app.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/codeforlife/app.py b/codeforlife/app.py index 963dd5a7..b46cf1fd 100644 --- a/codeforlife/app.py +++ b/codeforlife/app.py @@ -6,6 +6,7 @@ import multiprocessing import typing as t +from django.core.management import call_command from gunicorn.app.base import BaseApplication # type: ignore[import-untyped] @@ -19,6 +20,8 @@ class StandaloneApplication(BaseApplication): """ def __init__(self, app: t.Callable): + call_command("migrate", interactive=False) + self.options = { "bind": "0.0.0.0:8080", # https://docs.gunicorn.org/en/stable/design.html#how-many-workers From 19c790b823a68eb464831197d679f5b893955624 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 13 Nov 2024 10:00:59 +0000 Subject: [PATCH 56/56] feedback --- codeforlife/user/models/user.py | 1 + 1 file changed, 1 insertion(+) diff --git a/codeforlife/user/models/user.py b/codeforlife/user/models/user.py index fe75ee25..93a557d8 100644 --- a/codeforlife/user/models/user.py +++ b/codeforlife/user/models/user.py @@ -64,6 +64,7 @@ class Meta(TypedModelMeta): def is_authenticated(self): return ( not self.session.auth_factors.exists() + and self.userprofile.is_verified if super().is_authenticated else False )