Skip to content

Commit

Permalink
Contributor backend 21 (#368)
Browse files Browse the repository at this point in the history
* new cfl package

* fix linting errors

* disable=too-many-ancestors

* abstract base login view and form

* add todo

* new cfl package

* new cfl package
  • Loading branch information
SKairinos authored Nov 13, 2024
1 parent a0821a5 commit 9d55561
Show file tree
Hide file tree
Showing 14 changed files with 31 additions and 140 deletions.
4 changes: 2 additions & 2 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ name = "pypi"
# 5. Run `pipenv install --dev` in your terminal.

[packages]
codeforlife = {ref = "v0.20.0", git = "https://github.com/ocadotechnology/codeforlife-package-python.git"}
codeforlife = {ref = "v0.21.0", git = "https://github.com/ocadotechnology/codeforlife-package-python.git"}
# 🚫 Don't add [packages] below that are inherited from the CFL package.
pyjwt = "==2.6.0" # TODO: upgrade to latest version
# TODO: Needed by RR. Remove when RR has moved to new system.
Expand All @@ -32,7 +32,7 @@ django-sekizai = "==2.0.0"
django-classy-tags = "==2.0.0"

[dev-packages]
codeforlife = {ref = "v0.20.0", git = "https://github.com/ocadotechnology/codeforlife-package-python.git", extras = ["dev"]}
codeforlife = {ref = "v0.21.0", git = "https://github.com/ocadotechnology/codeforlife-package-python.git", extras = ["dev"]}
# codeforlife = {file = "../codeforlife-package-python", editable = true, extras = ["dev"]}
# 🚫 Don't add [dev-packages] below that are inherited from the CFL package.

Expand Down
6 changes: 3 additions & 3 deletions Pipfile.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion src/api/serializers/auth_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rest_framework import serializers


# pylint: disable-next=missing-class-docstring
# pylint: disable-next=missing-class-docstring,too-many-ancestors
class AuthFactorSerializer(ModelSerializer[User, AuthFactor]):
class Meta:
model = AuthFactor
Expand Down
4 changes: 3 additions & 1 deletion src/api/serializers/auth_factor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from ..views import AuthFactorViewSet
from .auth_factor import AuthFactorSerializer

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors


# pylint: disable-next=missing-class-docstring
class TestAuthFactorSerializer(ModelSerializerTestCase[User, AuthFactor]):
model_serializer_class = AuthFactorSerializer
fixtures = ["school_2", "non_school_teacher"]
Expand Down
1 change: 1 addition & 0 deletions src/api/serializers/klass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .klass import WriteClassSerializer

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors


class TestWriteClassSerializer(ModelSerializerTestCase[User, Class]):
Expand Down
2 changes: 1 addition & 1 deletion src/api/serializers/school_teacher_invitation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors


class TestSchoolTeacherInvitationSerializer(
Expand Down Expand Up @@ -68,7 +69,6 @@ def test_create(self, invitation_make_password: Mock):
invitation_make_password.assert_called_once()


# pylint: disable-next=missing-class-docstring
class TestRefreshSchoolTeacherInvitationSerializer(
ModelSerializerTestCase[User, SchoolTeacherInvitation]
):
Expand Down
4 changes: 3 additions & 1 deletion src/api/serializers/school_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from ..views.school import SchoolViewSet
from .school import SchoolSerializer

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors


# pylint: disable-next=missing-class-docstring
class TestSchoolSerializer(ModelSerializerTestCase[User, School]):
model_serializer_class = SchoolSerializer
fixtures = ["school_1"]
Expand Down
1 change: 1 addition & 0 deletions src/api/serializers/teacher_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors


class TestCreateTeacherSerializer(ModelSerializerTestCase[User, Teacher]):
Expand Down
1 change: 1 addition & 0 deletions src/api/serializers/user_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors


class TestBaseUserSerializer(ModelSerializerTestCase[User, User]):
Expand Down
1 change: 1 addition & 0 deletions src/api/views/school.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def get_permissions(self):

return super().get_permissions()

# pylint: disable-next=missing-function-docstring
def destroy(self, request, *args, **kwargs):
school = self.get_object()

Expand Down
1 change: 1 addition & 0 deletions src/rapid_router/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TypedModelMeta = object


# pylint: disable-next=too-many-ancestors
class User(_User):
"""A Rapid Router user."""

Expand Down
1 change: 1 addition & 0 deletions src/rapid_router/serializers/level_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .level import LockLevelListSerializer, LockLevelSerializer

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors


class TestLockLevelSerializer(ModelSerializerTestCase[User, Level]):
Expand Down
76 changes: 6 additions & 70 deletions src/sso/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,77 +3,13 @@
Created on 01/12/2023 at 16:00:24(+00:00).
"""

from codeforlife.forms import BaseLoginForm
from codeforlife.user.models import User
from django import forms
from django.contrib.auth import authenticate
from django.core.exceptions import ValidationError
from django.core.handlers.wsgi import WSGIRequest
from django.core.validators import RegexValidator


class BaseLoginForm(forms.Form):
"""
Base login form that all other login forms must inherit.
"""

user: User

def __init__(self, request: WSGIRequest, *args, **kwargs):
self.request = request
super().__init__(*args, **kwargs)

def clean(self):
"""Authenticates a user.
Raises:
ValidationError: If there are form errors.
ValidationError: If the user's credentials were incorrect.
ValidationError: If the user's account is deactivated.
Returns:
The cleaned form data.
"""

if self.errors:
raise ValidationError(
"Found form errors. Skipping authentication.",
code="form_errors",
)

user = authenticate(
self.request,
**{key: self.cleaned_data[key] for key in self.fields.keys()}
)
if user is None:
raise ValidationError(
self.get_invalid_login_error_message(),
code="invalid_login",
)
if not isinstance(user, User):
raise ValidationError(
"Incorrect user class.",
code="incorrect_user_class",
)
self.user = user

if not user.is_active:
raise ValidationError(
"User is not active",
code="user_not_active",
)

return self.cleaned_data

def get_invalid_login_error_message(self) -> str:
"""Returns the error message if the user failed to login.
Raises:
NotImplementedError: If message is not set.
"""
raise NotImplementedError()


class EmailLoginForm(BaseLoginForm):
class EmailLoginForm(BaseLoginForm[User]):
"""Log in with an email address."""

email = forms.EmailField()
Expand All @@ -86,7 +22,7 @@ def get_invalid_login_error_message(self):
)


class OtpLoginForm(BaseLoginForm):
class OtpLoginForm(BaseLoginForm[User]):
"""Log in with an OTP code."""

otp = forms.CharField(
Expand All @@ -99,7 +35,7 @@ def get_invalid_login_error_message(self):
return "Please enter the correct one-time password."


class OtpBypassTokenLoginForm(BaseLoginForm):
class OtpBypassTokenLoginForm(BaseLoginForm[User]):
"""Log in with an OTP-bypass token."""

token = forms.CharField(min_length=8, max_length=8)
Expand All @@ -108,7 +44,7 @@ def get_invalid_login_error_message(self):
return "Must be exactly 8 characters. A token can only be used once."


class StudentLoginForm(BaseLoginForm):
class StudentLoginForm(BaseLoginForm[User]):
"""Log in as a student."""

first_name = forms.CharField()
Expand All @@ -133,7 +69,7 @@ def get_invalid_login_error_message(self):
)


class StudentAutoLoginForm(BaseLoginForm):
class StudentAutoLoginForm(BaseLoginForm[User]):
"""Log in with the user's id."""

student_id = forms.IntegerField(min_value=1)
Expand Down
67 changes: 6 additions & 61 deletions src/sso/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,20 @@
split these views into multiple files.
"""

import json
import logging
import typing as t
from urllib.parse import quote_plus

from codeforlife.mixins import CronMixin
from codeforlife.request import HttpRequest
from codeforlife.user.models import User
from codeforlife.views import BaseLoginView
from common.models import UserSession # type: ignore
from django.conf import settings
from django.contrib.auth import login
from django.contrib.auth.views import LoginView as _LoginView
from django.contrib.sessions.models import Session, SessionManager
from django.core import management
from django.http import JsonResponse
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView

from .forms import (
BaseLoginForm,
EmailLoginForm,
OtpBypassTokenLoginForm,
OtpLoginForm,
Expand All @@ -36,7 +30,7 @@


# pylint: disable-next=too-many-ancestors
class LoginView(_LoginView):
class LoginView(BaseLoginView[HttpRequest[User], User]):
"""
Extends Django's login view to allow a user to log in using one of the
approved forms.
Expand All @@ -45,8 +39,6 @@ class LoginView(_LoginView):
industry standard security measures that a login view should have.
"""

request: HttpRequest

def get_form_class(self):
form = self.kwargs["form"]
if form == "login-with-email":
Expand All @@ -62,21 +54,7 @@ def get_form_class(self):

raise NameError(f'Unsupported form: "{form}".')

def get_form_kwargs(self):
form_kwargs = super().get_form_kwargs()
form_kwargs["data"] = json.loads(self.request.body)

return form_kwargs

def form_valid(self, form: BaseLoginForm): # type: ignore
user = form.user

# Clear expired sessions.
self.request.session.clear_expired(user.pk)

# Create session (without data).
login(self.request, user)

def get_session_metadata(self, user):
# TODO: use google analytics
user_session: t.Dict[str, t.Any] = {"user": user}
if self.get_form_class() in [StudentAutoLoginForm, StudentLoginForm]:
Expand All @@ -88,17 +66,13 @@ def form_valid(self, form: BaseLoginForm): # type: ignore
)
UserSession.objects.create(**user_session)

# Save session (with data).
self.request.session.save()

user_type = "indy"
if user.teacher:
user_type = "teacher"
elif user.student and user.student.class_field:
user_type = "student"

# Get session metadata.
session_metadata = {
return {
"user_id": user.id,
"user_type": user_type,
"auth_factors": list(
Expand All @@ -109,37 +83,8 @@ def form_valid(self, form: BaseLoginForm): # type: ignore
"otp_bypass_token_exists": user.otp_bypass_tokens.exists(),
}

# Return session metadata in response and a non-HTTP-only cookie.
response = JsonResponse(session_metadata)
response.set_cookie(
key=settings.SESSION_METADATA_COOKIE_NAME,
value=quote_plus(
json.dumps(
session_metadata,
separators=(",", ":"),
indent=None,
)
),
max_age=(
None
if settings.SESSION_EXPIRE_AT_BROWSER_CLOSE
else settings.SESSION_COOKIE_AGE
),
secure=settings.SESSION_COOKIE_SECURE,
samesite=t.cast(
t.Optional[t.Literal["Lax", "Strict", "None", False]],
settings.SESSION_COOKIE_SAMESITE,
),
domain=settings.SESSION_COOKIE_DOMAIN,
httponly=False,
)

return response

def form_invalid(self, form: BaseLoginForm): # type: ignore
return JsonResponse(form.errors, status=status.HTTP_400_BAD_REQUEST)


# TODO: move to python package and make work on AWS
class ClearExpiredView(CronMixin, APIView): # type: ignore
"""Clear all expired sessions."""

Expand Down

0 comments on commit 9d55561

Please sign in to comment.