-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* avoid importing models * fix issues * fix type imports * fix: type as vars * fix: type imports * fix: imports * fix: imports * base request * fix: types * ignore duplicate code * disable duplicate code * split request objects * session param * disable duplicate code * fix: abstract api request factory * import base api request factory * split model list serializer * abstract model view and serializer * abstract model list * import BaseModelListSerializer * disable missing-function-docstring * init request * fix: init request * abstract model serializer test case * fix types * fix linting issues * split code * fix: abstract api test case and client * # pylint: disable-next=too-many-ancestors * split code * abstract user and session * fix: type hints * fix types * fix types * disable too-many-ancestors * fix linting * abstract model view set test case and client * import base classes * fix: session def * mypy ignore * remove id field * abstract is authenticated * fix: comment out check * delete unnecessary code * fix pre setup * disable no-member * model serializer type arg * AnyBaseModelViewSet * AnyBaseModelViewSet * fix type hints * base login view and form * fix: import * get arg helper * delete unused var * migrate on app startup * feedback
- Loading branch information
Showing
34 changed files
with
2,511 additions
and
1,564 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
""" | ||
© Ocado Group | ||
Created on 07/11/2024 at 15:08:33(+00:00). | ||
""" | ||
|
||
import typing as t | ||
|
||
from django import forms | ||
from django.contrib.auth import authenticate | ||
from django.core.exceptions import ValidationError | ||
from django.core.handlers.wsgi import WSGIRequest | ||
|
||
from .models import AbstractBaseUser | ||
from .types import get_arg | ||
|
||
AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) | ||
|
||
|
||
class BaseLoginForm(forms.Form, t.Generic[AnyAbstractBaseUser]): | ||
"""Base login form that all other login forms must inherit.""" | ||
|
||
user: AnyAbstractBaseUser | ||
|
||
@classmethod | ||
def get_user_class(cls) -> t.Type[AnyAbstractBaseUser]: | ||
"""Get the user class.""" | ||
return get_arg(cls, 0) | ||
|
||
def __init__(self, request: WSGIRequest, *args, **kwargs): | ||
self.request = request | ||
super().__init__(*args, **kwargs) | ||
|
||
def clean(self): | ||
"""Authenticates a user. | ||
Raises: | ||
ValidationError: If there are form errors. | ||
ValidationError: If the user's credentials were incorrect. | ||
ValidationError: If the user's account is deactivated. | ||
Returns: | ||
The cleaned form data. | ||
""" | ||
|
||
if self.errors: | ||
raise ValidationError( | ||
"Found form errors. Skipping authentication.", | ||
code="form_errors", | ||
) | ||
|
||
user = authenticate( | ||
self.request, | ||
**{key: self.cleaned_data[key] for key in self.fields.keys()} | ||
) | ||
if user is None: | ||
raise ValidationError( | ||
self.get_invalid_login_error_message(), | ||
code="invalid_login", | ||
) | ||
if not isinstance(user, self.get_user_class()): | ||
raise ValidationError( | ||
"Incorrect user class.", | ||
code="incorrect_user_class", | ||
) | ||
if not user.is_active: | ||
raise ValidationError( | ||
"User is not active", | ||
code="user_not_active", | ||
) | ||
|
||
self.user = user | ||
|
||
return self.cleaned_data | ||
|
||
def get_invalid_login_error_message(self) -> str: | ||
"""Returns the error message if the user failed to login. | ||
Raises: | ||
NotImplementedError: If message is not set. | ||
""" | ||
raise NotImplementedError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
""" | ||
© Ocado Group | ||
Created on 06/11/2024 at 16:44:56(+00:00). | ||
""" | ||
|
||
import typing as t | ||
|
||
from django.contrib.sessions.base_session import ( | ||
AbstractBaseSession as _AbstractBaseSession, | ||
) | ||
from django.db import models | ||
from django.utils import timezone | ||
from django.utils.translation import gettext_lazy as _ | ||
|
||
from .abstract_base_user import AbstractBaseUser | ||
|
||
# pylint: disable=duplicate-code | ||
if t.TYPE_CHECKING: | ||
from django_stubs_ext.db.models import TypedModelMeta | ||
|
||
from .base_session_store import BaseSessionStore | ||
else: | ||
TypedModelMeta = object | ||
|
||
AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) | ||
# pylint: enable=duplicate-code | ||
|
||
|
||
class AbstractBaseSession(_AbstractBaseSession): | ||
""" | ||
Base session class to be inherited by all session classes. | ||
https://docs.djangoproject.com/en/3.2/topics/http/sessions/#example | ||
""" | ||
|
||
pk: str # type: ignore[assignment] | ||
|
||
user_id: int | ||
|
||
# pylint: disable-next=missing-class-docstring,too-few-public-methods | ||
class Meta(TypedModelMeta): | ||
abstract = True | ||
verbose_name = _("session") | ||
verbose_name_plural = _("sessions") | ||
|
||
@property | ||
def is_expired(self): | ||
"""Whether or not this session has expired.""" | ||
return self.expire_date < timezone.now() | ||
|
||
@property | ||
def store(self): | ||
"""A store instance for this session.""" | ||
return self.get_session_store_class()(self.session_key) | ||
|
||
@classmethod | ||
def get_session_store_class(cls) -> t.Type["BaseSessionStore"]: | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
def init_user_field(user_class: t.Type[AnyAbstractBaseUser]): | ||
"""Initializes the user field that relates a session to a user. | ||
Example: | ||
class Session(AbstractBaseSession): | ||
user = AbstractBaseSession.init_user_field(User) | ||
Args: | ||
user_class: The user model to associate sessions to. | ||
Returns: | ||
A one-to-one field that relates to the provided user model. | ||
""" | ||
return models.OneToOneField( | ||
user_class, | ||
null=True, | ||
blank=True, | ||
on_delete=models.CASCADE, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
""" | ||
© Ocado Group | ||
Created on 06/11/2024 at 16:38:15(+00:00). | ||
""" | ||
|
||
import typing as t | ||
from functools import cached_property | ||
|
||
from django.apps import apps | ||
from django.conf import settings | ||
from django.contrib.auth.models import AbstractBaseUser as _AbstractBaseUser | ||
from django.utils.translation import gettext_lazy as _ | ||
|
||
if t.TYPE_CHECKING: | ||
from django_stubs_ext.db.models import TypedModelMeta | ||
|
||
from .abstract_base_session import AbstractBaseSession | ||
else: | ||
TypedModelMeta = object | ||
|
||
|
||
class AbstractBaseUser(_AbstractBaseUser): | ||
""" | ||
Base user class to be inherited by all user classes. | ||
https://docs.djangoproject.com/en/3.2/topics/auth/customizing/#using-a-custom-user-model-when-starting-a-project | ||
""" | ||
|
||
pk: int | ||
session: "AbstractBaseSession" | ||
|
||
# pylint: disable-next=missing-class-docstring,too-few-public-methods | ||
class Meta(TypedModelMeta): | ||
abstract = True | ||
verbose_name = _("user") | ||
verbose_name_plural = _("users") | ||
|
||
@cached_property | ||
def _session_class(self): | ||
return t.cast( | ||
t.Type["AbstractBaseSession"], | ||
apps.get_model( | ||
app_label=( | ||
t.cast(str, settings.SESSION_ENGINE) | ||
.lower() | ||
.removesuffix(".models.session") | ||
.split(".")[-1] | ||
), | ||
model_name="session", | ||
), | ||
) | ||
|
||
@property | ||
def is_authenticated(self): | ||
"""A flag designating if this contributor has authenticated.""" | ||
try: | ||
return self.is_active and not self.session.is_expired | ||
except self._session_class.DoesNotExist: | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
""" | ||
© Ocado Group | ||
Created on 06/11/2024 at 17:31:32(+00:00). | ||
""" | ||
|
||
import typing as t | ||
|
||
from django.contrib.auth import SESSION_KEY | ||
from django.contrib.sessions.backends.db import SessionStore | ||
from django.utils import timezone | ||
|
||
from ..types import get_arg | ||
|
||
if t.TYPE_CHECKING: | ||
from .abstract_base_session import AbstractBaseSession | ||
from .abstract_base_user import AbstractBaseUser | ||
|
||
AnyAbstractBaseSession = t.TypeVar( | ||
"AnyAbstractBaseSession", bound=AbstractBaseSession | ||
) | ||
AnyAbstractBaseUser = t.TypeVar( | ||
"AnyAbstractBaseUser", bound=AbstractBaseUser | ||
) | ||
else: | ||
AnyAbstractBaseSession = t.TypeVar("AnyAbstractBaseSession") | ||
AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser") | ||
|
||
|
||
class BaseSessionStore( | ||
SessionStore, | ||
t.Generic[AnyAbstractBaseSession, AnyAbstractBaseUser], | ||
): | ||
""" | ||
Base session store class to be inherited by all session store classes. | ||
https://docs.djangoproject.com/en/3.2/topics/http/sessions/#example | ||
""" | ||
|
||
@classmethod | ||
def get_model_class(cls) -> t.Type[AnyAbstractBaseSession]: | ||
return get_arg(cls, 0) | ||
|
||
@classmethod | ||
def get_user_class(cls) -> t.Type[AnyAbstractBaseUser]: | ||
"""Get the user class.""" | ||
return get_arg(cls, 1) | ||
|
||
def associate_session_to_user( | ||
self, session: AnyAbstractBaseSession, user_id: int | ||
): | ||
"""Associate an anon session to a user. | ||
Args: | ||
session: The anon session. | ||
user_id: The user to associate. | ||
""" | ||
objects = self.get_user_class().objects # type: ignore[attr-defined] | ||
session.user = objects.get(id=user_id) # type: ignore[attr-defined] | ||
|
||
def create_model_instance(self, data): | ||
try: | ||
user_id = int(data.get(SESSION_KEY)) | ||
except (ValueError, TypeError): | ||
# Create an anon session. | ||
return super().create_model_instance(data) | ||
|
||
model_class = self.get_model_class() | ||
|
||
try: | ||
session = model_class.objects.get( | ||
user_id=user_id, # type: ignore[misc] | ||
) | ||
except model_class.DoesNotExist: | ||
session = model_class.objects.get(session_key=self.session_key) | ||
self.associate_session_to_user( | ||
t.cast(AnyAbstractBaseSession, session), user_id | ||
) | ||
|
||
session.session_data = self.encode(data) | ||
|
||
return session | ||
|
||
@classmethod | ||
def clear_expired(cls, user_id=None): | ||
session_query = cls.get_model_class().objects.filter( | ||
expire_date__lt=timezone.now() | ||
) | ||
|
||
if user_id is not None: | ||
session_query = session_query.filter(user_id=user_id) | ||
|
||
session_query.delete() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
""" | ||
© Ocado Group | ||
Created on 05/11/2024 at 14:40:32(+00:00). | ||
""" | ||
|
||
from .drf import BaseRequest, Request | ||
from .http import BaseHttpRequest, HttpRequest | ||
from .wsgi import BaseWSGIRequest, WSGIRequest |
Oops, something went wrong.