-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
210 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
""" | ||
© Ocado Group | ||
Created on 06/11/2024 at 16:44:56(+00:00). | ||
""" | ||
|
||
import typing as t | ||
|
||
from django.contrib.auth import get_user_model | ||
from django.contrib.sessions.base_session import ( | ||
AbstractBaseSession as _AbstractBaseSession, | ||
) | ||
from django.db import models | ||
from django.utils import timezone | ||
from django.utils.translation import gettext_lazy as _ | ||
|
||
from .abstract_base_user import AbstractBaseUser | ||
|
||
if t.TYPE_CHECKING: | ||
from django_stubs_ext.db.models import TypedModelMeta | ||
|
||
from .base_session_store import BaseSessionStore | ||
else: | ||
TypedModelMeta = object | ||
|
||
|
||
class AbstractBaseSession(_AbstractBaseSession): | ||
""" | ||
Base session class to be inherited by all session classes. | ||
https://docs.djangoproject.com/en/3.2/topics/http/sessions/#example | ||
""" | ||
|
||
pk: str # type: ignore[assignment] | ||
|
||
user_id: int | ||
user = models.OneToOneField( | ||
t.cast(t.Type[AbstractBaseUser], get_user_model()), | ||
null=True, | ||
blank=True, | ||
on_delete=models.CASCADE, | ||
) | ||
|
||
# pylint: disable-next=missing-class-docstring,too-few-public-methods | ||
class Meta(TypedModelMeta): | ||
abstract = True | ||
verbose_name = _("session") | ||
verbose_name_plural = _("sessions") | ||
|
||
@property | ||
def is_expired(self): | ||
"""Whether or not this session has expired.""" | ||
return self.expire_date < timezone.now() | ||
|
||
@property | ||
def store(self): | ||
"""A store instance for this session.""" | ||
return self.get_session_store_class()(self.session_key) | ||
|
||
@classmethod | ||
def get_session_store_class(cls) -> t.Type["BaseSessionStore"]: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
""" | ||
© Ocado Group | ||
Created on 06/11/2024 at 16:38:15(+00:00). | ||
""" | ||
|
||
import typing as t | ||
|
||
from django.contrib.auth.models import AbstractBaseUser as _AbstractBaseUser | ||
from django.utils.translation import gettext_lazy as _ | ||
|
||
if t.TYPE_CHECKING: | ||
from django_stubs_ext.db.models import TypedModelMeta | ||
|
||
from .abstract_base_session import AbstractBaseSession | ||
else: | ||
TypedModelMeta = object | ||
|
||
|
||
class AbstractBaseUser(_AbstractBaseUser): | ||
""" | ||
Base user class to be inherited by all user classes. | ||
https://docs.djangoproject.com/en/3.2/topics/auth/customizing/#using-a-custom-user-model-when-starting-a-project | ||
""" | ||
|
||
id: int | ||
pk: int | ||
session: "AbstractBaseSession" | ||
|
||
# pylint: disable-next=missing-class-docstring,too-few-public-methods | ||
class Meta(TypedModelMeta): | ||
abstract = True | ||
verbose_name = _("user") | ||
verbose_name_plural = _("users") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
""" | ||
© Ocado Group | ||
Created on 06/11/2024 at 17:31:32(+00:00). | ||
""" | ||
|
||
import typing as t | ||
|
||
from django.contrib.auth import SESSION_KEY | ||
from django.contrib.sessions.backends.db import SessionStore | ||
from django.utils import timezone | ||
|
||
if t.TYPE_CHECKING: | ||
from .abstract_base_session import AbstractBaseSession | ||
from .abstract_base_user import AbstractBaseUser | ||
|
||
AnyAbstractBaseSession = t.TypeVar( | ||
"AnyAbstractBaseSession", bound=AbstractBaseSession | ||
) | ||
AnyAbstractBaseUser = t.TypeVar( | ||
"AnyAbstractBaseUser", bound=AbstractBaseUser | ||
) | ||
else: | ||
AnyAbstractBaseSession = t.TypeVar("AnyAbstractBaseSession") | ||
AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser") | ||
|
||
|
||
class BaseSessionStore( | ||
SessionStore, | ||
t.Generic[AnyAbstractBaseSession, AnyAbstractBaseUser], | ||
): | ||
""" | ||
Base session store class to be inherited by all session store classes. | ||
https://docs.djangoproject.com/en/3.2/topics/http/sessions/#example | ||
""" | ||
|
||
@classmethod | ||
def get_model_class(cls) -> t.Type[AnyAbstractBaseSession]: | ||
# pylint: disable-next=no-member | ||
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] | ||
0 | ||
] | ||
|
||
@classmethod | ||
def get_user_class(cls) -> t.Type[AnyAbstractBaseUser]: | ||
# pylint: disable-next=no-member | ||
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] | ||
1 | ||
] | ||
|
||
def associate_session_to_user( | ||
self, session: AnyAbstractBaseSession, user_id: int | ||
): | ||
"""Associate an anon session to a user. | ||
Args: | ||
session: The anon session. | ||
user_id: The user to associate. | ||
""" | ||
session.user = self.get_user_class().objects.get(id=user_id) | ||
|
||
def create_model_instance(self, data): | ||
try: | ||
user_id = int(data.get(SESSION_KEY)) | ||
except (ValueError, TypeError): | ||
# Create an anon session. | ||
return super().create_model_instance(data) | ||
|
||
model_class = self.get_model_class() | ||
|
||
try: | ||
session = model_class.objects.get(user_id=user_id) | ||
except model_class.DoesNotExist: | ||
session = model_class.objects.get(session_key=self.session_key) | ||
self.associate_session_to_user(session, user_id) | ||
|
||
session.session_data = self.encode(data) | ||
|
||
return session | ||
|
||
@classmethod | ||
def clear_expired(cls, user_id=None): | ||
session_query = cls.get_model_class().objects.filter( | ||
expire_date__lt=timezone.now() | ||
) | ||
|
||
if user_id is not None: | ||
session_query = session_query.filter(user_id=user_id) | ||
|
||
session_query.delete() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters