Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Contributor backend 21 #368

Merged
merged 7 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ name = "pypi"
# 5. Run `pipenv install --dev` in your terminal.

[packages]
codeforlife = {ref = "v0.20.0", git = "https://github.com/ocadotechnology/codeforlife-package-python.git"}
codeforlife = {ref = "v0.21.0", git = "https://github.com/ocadotechnology/codeforlife-package-python.git"}
# 🚫 Don't add [packages] below that are inherited from the CFL package.
pyjwt = "==2.6.0" # TODO: upgrade to latest version
# TODO: Needed by RR. Remove when RR has moved to new system.
Expand All @@ -32,7 +32,7 @@ django-sekizai = "==2.0.0"
django-classy-tags = "==2.0.0"

[dev-packages]
codeforlife = {ref = "v0.20.0", git = "https://github.com/ocadotechnology/codeforlife-package-python.git", extras = ["dev"]}
codeforlife = {ref = "v0.21.0", git = "https://github.com/ocadotechnology/codeforlife-package-python.git", extras = ["dev"]}
# codeforlife = {file = "../codeforlife-package-python", editable = true, extras = ["dev"]}
# 🚫 Don't add [dev-packages] below that are inherited from the CFL package.

Expand Down
6 changes: 3 additions & 3 deletions Pipfile.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion src/api/serializers/auth_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rest_framework import serializers


# pylint: disable-next=missing-class-docstring
# pylint: disable-next=missing-class-docstring,too-many-ancestors
class AuthFactorSerializer(ModelSerializer[User, AuthFactor]):
class Meta:
model = AuthFactor
Expand Down
4 changes: 3 additions & 1 deletion src/api/serializers/auth_factor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from ..views import AuthFactorViewSet
from .auth_factor import AuthFactorSerializer

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors


# pylint: disable-next=missing-class-docstring
class TestAuthFactorSerializer(ModelSerializerTestCase[User, AuthFactor]):
model_serializer_class = AuthFactorSerializer
fixtures = ["school_2", "non_school_teacher"]
Expand Down
1 change: 1 addition & 0 deletions src/api/serializers/klass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .klass import WriteClassSerializer

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors


class TestWriteClassSerializer(ModelSerializerTestCase[User, Class]):
Expand Down
2 changes: 1 addition & 1 deletion src/api/serializers/school_teacher_invitation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors


class TestSchoolTeacherInvitationSerializer(
Expand Down Expand Up @@ -68,7 +69,6 @@ def test_create(self, invitation_make_password: Mock):
invitation_make_password.assert_called_once()


# pylint: disable-next=missing-class-docstring
class TestRefreshSchoolTeacherInvitationSerializer(
ModelSerializerTestCase[User, SchoolTeacherInvitation]
):
Expand Down
4 changes: 3 additions & 1 deletion src/api/serializers/school_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from ..views.school import SchoolViewSet
from .school import SchoolSerializer

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors


# pylint: disable-next=missing-class-docstring
class TestSchoolSerializer(ModelSerializerTestCase[User, School]):
model_serializer_class = SchoolSerializer
fixtures = ["school_1"]
Expand Down
1 change: 1 addition & 0 deletions src/api/serializers/teacher_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors


class TestCreateTeacherSerializer(ModelSerializerTestCase[User, Teacher]):
Expand Down
1 change: 1 addition & 0 deletions src/api/serializers/user_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors


class TestBaseUserSerializer(ModelSerializerTestCase[User, User]):
Expand Down
1 change: 1 addition & 0 deletions src/api/views/school.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def get_permissions(self):

return super().get_permissions()

# pylint: disable-next=missing-function-docstring
def destroy(self, request, *args, **kwargs):
school = self.get_object()

Expand Down
1 change: 1 addition & 0 deletions src/rapid_router/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TypedModelMeta = object


# pylint: disable-next=too-many-ancestors
class User(_User):
"""A Rapid Router user."""

Expand Down
1 change: 1 addition & 0 deletions src/rapid_router/serializers/level_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .level import LockLevelListSerializer, LockLevelSerializer

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors


class TestLockLevelSerializer(ModelSerializerTestCase[User, Level]):
Expand Down
76 changes: 6 additions & 70 deletions src/sso/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,77 +3,13 @@
Created on 01/12/2023 at 16:00:24(+00:00).
"""

from codeforlife.forms import BaseLoginForm
from codeforlife.user.models import User
from django import forms
from django.contrib.auth import authenticate
from django.core.exceptions import ValidationError
from django.core.handlers.wsgi import WSGIRequest
from django.core.validators import RegexValidator


class BaseLoginForm(forms.Form):
"""
Base login form that all other login forms must inherit.
"""

user: User

def __init__(self, request: WSGIRequest, *args, **kwargs):
self.request = request
super().__init__(*args, **kwargs)

def clean(self):
"""Authenticates a user.
Raises:
ValidationError: If there are form errors.
ValidationError: If the user's credentials were incorrect.
ValidationError: If the user's account is deactivated.
Returns:
The cleaned form data.
"""

if self.errors:
raise ValidationError(
"Found form errors. Skipping authentication.",
code="form_errors",
)

user = authenticate(
self.request,
**{key: self.cleaned_data[key] for key in self.fields.keys()}
)
if user is None:
raise ValidationError(
self.get_invalid_login_error_message(),
code="invalid_login",
)
if not isinstance(user, User):
raise ValidationError(
"Incorrect user class.",
code="incorrect_user_class",
)
self.user = user

if not user.is_active:
raise ValidationError(
"User is not active",
code="user_not_active",
)

return self.cleaned_data

def get_invalid_login_error_message(self) -> str:
"""Returns the error message if the user failed to login.
Raises:
NotImplementedError: If message is not set.
"""
raise NotImplementedError()


class EmailLoginForm(BaseLoginForm):
class EmailLoginForm(BaseLoginForm[User]):
"""Log in with an email address."""

email = forms.EmailField()
Expand All @@ -86,7 +22,7 @@ def get_invalid_login_error_message(self):
)


class OtpLoginForm(BaseLoginForm):
class OtpLoginForm(BaseLoginForm[User]):
"""Log in with an OTP code."""

otp = forms.CharField(
Expand All @@ -99,7 +35,7 @@ def get_invalid_login_error_message(self):
return "Please enter the correct one-time password."


class OtpBypassTokenLoginForm(BaseLoginForm):
class OtpBypassTokenLoginForm(BaseLoginForm[User]):
"""Log in with an OTP-bypass token."""

token = forms.CharField(min_length=8, max_length=8)
Expand All @@ -108,7 +44,7 @@ def get_invalid_login_error_message(self):
return "Must be exactly 8 characters. A token can only be used once."


class StudentLoginForm(BaseLoginForm):
class StudentLoginForm(BaseLoginForm[User]):
"""Log in as a student."""

first_name = forms.CharField()
Expand All @@ -133,7 +69,7 @@ def get_invalid_login_error_message(self):
)


class StudentAutoLoginForm(BaseLoginForm):
class StudentAutoLoginForm(BaseLoginForm[User]):
"""Log in with the user's id."""

student_id = forms.IntegerField(min_value=1)
Expand Down
67 changes: 6 additions & 61 deletions src/sso/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,20 @@
split these views into multiple files.
"""

import json
import logging
import typing as t
from urllib.parse import quote_plus

from codeforlife.mixins import CronMixin
from codeforlife.request import HttpRequest
from codeforlife.user.models import User
from codeforlife.views import BaseLoginView
from common.models import UserSession # type: ignore
from django.conf import settings
from django.contrib.auth import login
from django.contrib.auth.views import LoginView as _LoginView
from django.contrib.sessions.models import Session, SessionManager
from django.core import management
from django.http import JsonResponse
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView

from .forms import (
BaseLoginForm,
EmailLoginForm,
OtpBypassTokenLoginForm,
OtpLoginForm,
Expand All @@ -36,7 +30,7 @@


# pylint: disable-next=too-many-ancestors
class LoginView(_LoginView):
class LoginView(BaseLoginView[HttpRequest[User], User]):
"""
Extends Django's login view to allow a user to log in using one of the
approved forms.
Expand All @@ -45,8 +39,6 @@ class LoginView(_LoginView):
industry standard security measures that a login view should have.
"""

request: HttpRequest

def get_form_class(self):
form = self.kwargs["form"]
if form == "login-with-email":
Expand All @@ -62,21 +54,7 @@ def get_form_class(self):

raise NameError(f'Unsupported form: "{form}".')

def get_form_kwargs(self):
form_kwargs = super().get_form_kwargs()
form_kwargs["data"] = json.loads(self.request.body)

return form_kwargs

def form_valid(self, form: BaseLoginForm): # type: ignore
user = form.user

# Clear expired sessions.
self.request.session.clear_expired(user.pk)

# Create session (without data).
login(self.request, user)

def get_session_metadata(self, user):
# TODO: use google analytics
user_session: t.Dict[str, t.Any] = {"user": user}
if self.get_form_class() in [StudentAutoLoginForm, StudentLoginForm]:
Expand All @@ -88,17 +66,13 @@ def form_valid(self, form: BaseLoginForm): # type: ignore
)
UserSession.objects.create(**user_session)

# Save session (with data).
self.request.session.save()

user_type = "indy"
if user.teacher:
user_type = "teacher"
elif user.student and user.student.class_field:
user_type = "student"

# Get session metadata.
session_metadata = {
return {
"user_id": user.id,
"user_type": user_type,
"auth_factors": list(
Expand All @@ -109,37 +83,8 @@ def form_valid(self, form: BaseLoginForm): # type: ignore
"otp_bypass_token_exists": user.otp_bypass_tokens.exists(),
}

# Return session metadata in response and a non-HTTP-only cookie.
response = JsonResponse(session_metadata)
response.set_cookie(
key=settings.SESSION_METADATA_COOKIE_NAME,
value=quote_plus(
json.dumps(
session_metadata,
separators=(",", ":"),
indent=None,
)
),
max_age=(
None
if settings.SESSION_EXPIRE_AT_BROWSER_CLOSE
else settings.SESSION_COOKIE_AGE
),
secure=settings.SESSION_COOKIE_SECURE,
samesite=t.cast(
t.Optional[t.Literal["Lax", "Strict", "None", False]],
settings.SESSION_COOKIE_SAMESITE,
),
domain=settings.SESSION_COOKIE_DOMAIN,
httponly=False,
)

return response

def form_invalid(self, form: BaseLoginForm): # type: ignore
return JsonResponse(form.errors, status=status.HTTP_400_BAD_REQUEST)


# TODO: move to python package and make work on AWS
class ClearExpiredView(CronMixin, APIView): # type: ignore
"""Clear all expired sessions."""

Expand Down