Skip to content

Commit

Permalink
abstract user and session
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Nov 6, 2024
1 parent 33d68a9 commit b058505
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 56 deletions.
3 changes: 3 additions & 0 deletions codeforlife/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,7 @@
Created on 19/01/2024 at 15:20:45(+00:00).
"""

from .abstract_base_session import AbstractBaseSession
from .abstract_base_user import AbstractBaseUser
from .base import *
from .base_session_store import BaseSessionStore
60 changes: 60 additions & 0 deletions codeforlife/models/abstract_base_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
© Ocado Group
Created on 06/11/2024 at 16:44:56(+00:00).
"""

import typing as t

from django.contrib.auth import get_user_model
from django.contrib.sessions.base_session import (
AbstractBaseSession as _AbstractBaseSession,
)
from django.db import models
from django.utils import timezone
from django.utils.translation import gettext_lazy as _

from .abstract_base_user import AbstractBaseUser

if t.TYPE_CHECKING:
from django_stubs_ext.db.models import TypedModelMeta

from .base_session_store import BaseSessionStore
else:
TypedModelMeta = object


class AbstractBaseSession(_AbstractBaseSession):
"""
Base session class to be inherited by all session classes.
https://docs.djangoproject.com/en/3.2/topics/http/sessions/#example
"""

pk: str # type: ignore[assignment]

user_id: int
user = models.OneToOneField(
t.cast(t.Type[AbstractBaseUser], get_user_model()),
null=True,
blank=True,
on_delete=models.CASCADE,
)

# pylint: disable-next=missing-class-docstring,too-few-public-methods
class Meta(TypedModelMeta):
abstract = True
verbose_name = _("session")
verbose_name_plural = _("sessions")

@property
def is_expired(self):
"""Whether or not this session has expired."""
return self.expire_date < timezone.now()

@property
def store(self):
"""A store instance for this session."""
return self.get_session_store_class()(self.session_key)

@classmethod
def get_session_store_class(cls) -> t.Type["BaseSessionStore"]:
raise NotImplementedError
33 changes: 33 additions & 0 deletions codeforlife/models/abstract_base_user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
© Ocado Group
Created on 06/11/2024 at 16:38:15(+00:00).
"""

import typing as t

from django.contrib.auth.models import AbstractBaseUser as _AbstractBaseUser
from django.utils.translation import gettext_lazy as _

if t.TYPE_CHECKING:
from django_stubs_ext.db.models import TypedModelMeta

from .abstract_base_session import AbstractBaseSession
else:
TypedModelMeta = object


class AbstractBaseUser(_AbstractBaseUser):
"""
Base user class to be inherited by all user classes.
https://docs.djangoproject.com/en/3.2/topics/auth/customizing/#using-a-custom-user-model-when-starting-a-project
"""

id: int
pk: int
session: "AbstractBaseSession"

# pylint: disable-next=missing-class-docstring,too-few-public-methods
class Meta(TypedModelMeta):
abstract = True
verbose_name = _("user")
verbose_name_plural = _("users")
89 changes: 89 additions & 0 deletions codeforlife/models/base_session_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
© Ocado Group
Created on 06/11/2024 at 17:31:32(+00:00).
"""

import typing as t

from django.contrib.auth import SESSION_KEY
from django.contrib.sessions.backends.db import SessionStore
from django.utils import timezone

if t.TYPE_CHECKING:
from .abstract_base_session import AbstractBaseSession
from .abstract_base_user import AbstractBaseUser

AnyAbstractBaseSession = t.TypeVar(
"AnyAbstractBaseSession", bound=AbstractBaseSession
)
AnyAbstractBaseUser = t.TypeVar(
"AnyAbstractBaseUser", bound=AbstractBaseUser
)
else:
AnyAbstractBaseSession = t.TypeVar("AnyAbstractBaseSession")
AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser")


class BaseSessionStore(
SessionStore,
t.Generic[AnyAbstractBaseSession, AnyAbstractBaseUser],
):
"""
Base session store class to be inherited by all session store classes.
https://docs.djangoproject.com/en/3.2/topics/http/sessions/#example
"""

@classmethod
def get_model_class(cls) -> t.Type[AnyAbstractBaseSession]:
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]

@classmethod
def get_user_class(cls) -> t.Type[AnyAbstractBaseUser]:
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
1
]

def associate_session_to_user(
self, session: AnyAbstractBaseSession, user_id: int
):
"""Associate an anon session to a user.
Args:
session: The anon session.
user_id: The user to associate.
"""
session.user = self.get_user_class().objects.get(id=user_id)

def create_model_instance(self, data):
try:
user_id = int(data.get(SESSION_KEY))
except (ValueError, TypeError):
# Create an anon session.
return super().create_model_instance(data)

model_class = self.get_model_class()

try:
session = model_class.objects.get(user_id=user_id)
except model_class.DoesNotExist:
session = model_class.objects.get(session_key=self.session_key)
self.associate_session_to_user(session, user_id)

session.session_data = self.encode(data)

return session

@classmethod
def clear_expired(cls, user_id=None):
session_query = cls.get_model_class().objects.filter(
expire_date__lt=timezone.now()
)

if user_id is not None:
session_query = session_query.filter(user_id=user_id)

session_query.delete()
68 changes: 13 additions & 55 deletions codeforlife/user/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@

import typing as t

from django.contrib.auth import SESSION_KEY
from django.contrib.sessions.backends.db import SessionStore as DBStore
from django.contrib.sessions.base_session import AbstractBaseSession
from django.db import models
from django.db.models.query import QuerySet
from django.utils import timezone

from ...models import AbstractBaseSession, BaseSessionStore
from .user import User

if t.TYPE_CHECKING: # pragma: no cover
Expand All @@ -26,29 +23,20 @@ class Session(AbstractBaseSession):

auth_factors: QuerySet["SessionAuthFactor"]

user = models.OneToOneField(
# TODO: remove in new schema
user = models.OneToOneField( # type: ignore[assignment]
User,
null=True,
blank=True,
on_delete=models.CASCADE,
)

@property
def is_expired(self):
"""Whether or not this session has expired."""
return self.expire_date < timezone.now()

@property
def store(self):
"""A store instance for this session."""
return self.get_session_store_class()(self.session_key)

@classmethod
def get_session_store_class(cls):
return SessionStore


class SessionStore(DBStore):
class SessionStore(BaseSessionStore[Session, User]):
"""
A custom session store interface to support:
1. creating only one session per user;
Expand All @@ -57,44 +45,14 @@ class SessionStore(DBStore):
https://docs.djangoproject.com/en/3.2/topics/http/sessions/#example
"""

@classmethod
def get_model_class(cls):
return Session

def create_model_instance(self, data):
try:
user_id = int(data.get(SESSION_KEY))
except (ValueError, TypeError):
# Create an anon session.
return super().create_model_instance(data)

model_class = self.get_model_class()
def associate_session_to_user(self, session, user_id):
# pylint: disable-next=import-outside-toplevel
from .session_auth_factor import SessionAuthFactor

try:
session = model_class.objects.get(user_id=user_id)
except model_class.DoesNotExist:
# pylint: disable-next=import-outside-toplevel
from .session_auth_factor import SessionAuthFactor

# Associate session to user.
session = model_class.objects.get(session_key=self.session_key)
session.user = User.objects.get(id=user_id)
SessionAuthFactor.objects.bulk_create(
[
SessionAuthFactor(session=session, auth_factor=auth_factor)
for auth_factor in session.user.auth_factors.all()
]
)

session.session_data = self.encode(data)

return session

@classmethod
def clear_expired(cls, user_id: t.Optional[int] = None):
session_query = cls.get_model_class().objects.filter(
expire_date__lt=timezone.now()
super().associate_session_to_user(session, user_id)
SessionAuthFactor.objects.bulk_create(
[
SessionAuthFactor(session=session, auth_factor=auth_factor)
for auth_factor in session.user.auth_factors.all()
]
)
if user_id:
session_query = session_query.filter(user_id=user_id)
session_query.delete()
13 changes: 12 additions & 1 deletion codeforlife/user/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import string
import typing as t
from datetime import datetime

from common.models import TotalActivity, UserProfile

Expand All @@ -18,6 +19,7 @@
from pyotp import TOTP

from ... import mail
from ...models import AbstractBaseUser
from .klass import Class
from .school import School

Expand All @@ -33,7 +35,16 @@
TypedModelMeta = object


class User(_User):
# TODO: remove in new schema
class _AbstractBaseUser(AbstractBaseUser):
password: str = None # type: ignore[assignment]
last_login: datetime = None # type: ignore[assignment]

class Meta(TypedModelMeta):
abstract = True


class User(_AbstractBaseUser, _User):
"""A proxy to Django's user class."""

_password: t.Optional[str]
Expand Down

0 comments on commit b058505

Please sign in to comment.