diff --git a/backend/api/models/__init__.py b/backend/api/models/__init__.py index 6f468062..ca92dba9 100644 --- a/backend/api/models/__init__.py +++ b/backend/api/models/__init__.py @@ -2,5 +2,5 @@ © Ocado Group Created on 06/02/2024 at 15:13:00(+00:00). """ -# TODO: Move from common to here and update to match new models pattern -from common.models import SchoolTeacherInvitation + +from .school_teacher_invitation import SchoolTeacherInvitation diff --git a/backend/api/models/school_teacher_invitation.py b/backend/api/models/school_teacher_invitation.py new file mode 100644 index 00000000..d5a2ffd5 --- /dev/null +++ b/backend/api/models/school_teacher_invitation.py @@ -0,0 +1,10 @@ +""" +© Ocado Group +Created on 15/04/2024 at 15:13:33(+01:00). +""" + +# TODO: Move from common to here and update to match new models pattern +# pylint: disable-next=unused-import +from common.models import ( # type: ignore[import-untyped] + SchoolTeacherInvitation, +) diff --git a/backend/rapid_router/filters/level.py b/backend/rapid_router/filters/level.py index 27372ad0..adf09c57 100644 --- a/backend/rapid_router/filters/level.py +++ b/backend/rapid_router/filters/level.py @@ -3,7 +3,9 @@ Created on 03/04/2024 at 16:37:39(+01:00). """ -from django_filters import rest_framework as filters +from django_filters import ( # type: ignore[import-untyped] # isort: skip + rest_framework as filters, +) from ..models import Level diff --git a/backend/rapid_router/models/user.py b/backend/rapid_router/models/user.py index 3fdc577f..c945ec16 100644 --- a/backend/rapid_router/models/user.py +++ b/backend/rapid_router/models/user.py @@ -1,3 +1,5 @@ +# TODO: remove this in new system +# mypy: disable-error-code="import-untyped" """ © Ocado Group Created on 05/04/2024 at 12:40:10(+01:00). diff --git a/backend/rapid_router/serializers/level.py b/backend/rapid_router/serializers/level.py index 59522ff6..36d12ee1 100644 --- a/backend/rapid_router/serializers/level.py +++ b/backend/rapid_router/serializers/level.py @@ -8,7 +8,7 @@ from codeforlife.serializers import ModelListSerializer, ModelSerializer from codeforlife.user.models import Class -from common.models import DailyActivity +from common.models import DailyActivity # type: ignore[import-untyped] from django.db.models.query import QuerySet from rest_framework import serializers diff --git a/backend/service/settings.py b/backend/service/settings.py index 32dc188b..8d79df3a 100644 --- a/backend/service/settings.py +++ b/backend/service/settings.py @@ -113,7 +113,7 @@ AUTOCONFIG_INDEX_VIEW = "home" SITE_ID = 1 -PIPELINE = {} +PIPELINE = {} # type: ignore[var-annotated] FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:3000") diff --git a/backend/service/urls.py b/backend/service/urls.py index e43b9480..fafd26d5 100644 --- a/backend/service/urls.py +++ b/backend/service/urls.py @@ -14,10 +14,10 @@ 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) """ -from aimmo import urls as aimmo_urls +from aimmo import urls as aimmo_urls # type: ignore[import-untyped] from codeforlife.urls import service_urlpatterns from django.urls import include, path -from portal.views.aimmo.dashboard import ( +from portal.views.aimmo.dashboard import ( # type: ignore[import-untyped] StudentAimmoDashboard, TeacherAimmoDashboard, ) diff --git a/backend/sso/forms.py b/backend/sso/forms.py index c4d37eda..5308fb36 100644 --- a/backend/sso/forms.py +++ b/backend/sso/forms.py @@ -3,11 +3,10 @@ Created on 01/12/2023 at 16:00:24(+00:00). """ -import typing as t +from codeforlife.user.models import User from django import forms from django.contrib.auth import authenticate -from django.contrib.auth.base_user import AbstractBaseUser from django.contrib.auth.forms import UsernameField from django.core.exceptions import ValidationError from django.core.handlers.wsgi import WSGIRequest @@ -19,9 +18,10 @@ class BaseAuthForm(forms.Form): Base authentication form that all other authentication forms must inherit. """ + user: User + def __init__(self, request: WSGIRequest, *args, **kwargs): self.request = request - self.user: t.Optional[AbstractBaseUser] = None super().__init__(*args, **kwargs) def clean(self): @@ -42,16 +42,23 @@ def clean(self): code="form_errors", ) - self.user = authenticate( + user = authenticate( self.request, **{key: self.cleaned_data[key] for key in self.fields.keys()} ) - if self.user is None: + if user is None: raise ValidationError( self.get_invalid_login_error_message(), code="invalid_login", ) - if not self.user.is_active: + if not isinstance(user, User): + raise ValidationError( + "Incorrect user class.", + code="incorrect_user_class", + ) + self.user = user + + if not user.is_active: raise ValidationError( "User is not active", code="user_not_active", diff --git a/backend/sso/tests/test_views.py b/backend/sso/tests/test_views.py index 82dd193f..4226fd9c 100644 --- a/backend/sso/tests/test_views.py +++ b/backend/sso/tests/test_views.py @@ -8,11 +8,10 @@ from unittest.mock import patch import pyotp -from codeforlife.tests import CronTestCase +from codeforlife.tests import Client, CronTestCase, TestCase from codeforlife.user.models import AuthFactor, User from django.core import management from django.http import HttpResponse -from django.test import TestCase from django.urls import reverse from django.utils import timezone @@ -20,6 +19,9 @@ class TestLoginView(TestCase): """Test the login view.""" + client: Client + client_class = Client + def setUp(self): self.user = User.objects.get(id=2) diff --git a/backend/sso/views.py b/backend/sso/views.py index 0923d44f..6fd03d31 100644 --- a/backend/sso/views.py +++ b/backend/sso/views.py @@ -64,20 +64,20 @@ def get_form_class(self): raise NameError(f'Unsupported form: "{form}".') def form_valid(self, form: BaseAuthForm): # type: ignore - if form.user is None: - raise ValueError("User should NOT be none.") + user = form.user # Clear expired sessions. - self.request.session.clear_expired(form.user.pk) + self.request.session.clear_expired(user.pk) # Create session (without data). - login(self.request, form.user) - user = self.request.user + login(self.request, user) # TODO: use google analytics - user_session = {"user": form.user} + user_session: t.Dict[str, t.Any] = {"user": user} if self.get_form_class() in [UsernameAuthForm, UserIdAuthForm]: - user_session["class_field"] = form.user.new_student.class_field + user_session[ + "class_field" + ] = user.new_student.class_field # type: ignore[attr-defined] user_session["login_type"] = ( "direct" if "user_id" in self.request.POST else "classform" ) @@ -108,7 +108,10 @@ def form_valid(self, form: BaseAuthForm): # type: ignore else settings.SESSION_COOKIE_AGE ), secure=settings.SESSION_COOKIE_SECURE, - samesite=settings.SESSION_COOKIE_SAMESITE, + samesite=t.cast( + t.Optional[t.Literal["Lax", "Strict", "None", False]], + settings.SESSION_COOKIE_SAMESITE, + ), domain=settings.SESSION_COOKIE_DOMAIN, httponly=False, )