Skip to content

Commit

Permalink
fix static type hint errors
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Apr 15, 2024
1 parent 684a57e commit 74434a2
Show file tree
Hide file tree
Showing 10 changed files with 49 additions and 23 deletions.
4 changes: 2 additions & 2 deletions backend/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
© Ocado Group
Created on 06/02/2024 at 15:13:00(+00:00).
"""
# TODO: Move from common to here and update to match new models pattern
from common.models import SchoolTeacherInvitation

from .school_teacher_invitation import SchoolTeacherInvitation
10 changes: 10 additions & 0 deletions backend/api/models/school_teacher_invitation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""
© Ocado Group
Created on 15/04/2024 at 15:13:33(+01:00).
"""

# TODO: Move from common to here and update to match new models pattern
# pylint: disable-next=unused-import
from common.models import ( # type: ignore[import-untyped]
SchoolTeacherInvitation,
)
4 changes: 3 additions & 1 deletion backend/rapid_router/filters/level.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
Created on 03/04/2024 at 16:37:39(+01:00).
"""

from django_filters import rest_framework as filters
from django_filters import ( # type: ignore[import-untyped] # isort: skip
rest_framework as filters,
)

from ..models import Level

Expand Down
2 changes: 2 additions & 0 deletions backend/rapid_router/models/user.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# TODO: remove this in new system
# mypy: disable-error-code="import-untyped"
"""
© Ocado Group
Created on 05/04/2024 at 12:40:10(+01:00).
Expand Down
2 changes: 1 addition & 1 deletion backend/rapid_router/serializers/level.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from codeforlife.serializers import ModelListSerializer, ModelSerializer
from codeforlife.user.models import Class
from common.models import DailyActivity
from common.models import DailyActivity # type: ignore[import-untyped]
from django.db.models.query import QuerySet
from rest_framework import serializers

Expand Down
2 changes: 1 addition & 1 deletion backend/service/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
AUTOCONFIG_INDEX_VIEW = "home"
SITE_ID = 1

PIPELINE = {}
PIPELINE = {} # type: ignore[var-annotated]

FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:3000")

Expand Down
4 changes: 2 additions & 2 deletions backend/service/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""

from aimmo import urls as aimmo_urls
from aimmo import urls as aimmo_urls # type: ignore[import-untyped]
from codeforlife.urls import service_urlpatterns
from django.urls import include, path
from portal.views.aimmo.dashboard import (
from portal.views.aimmo.dashboard import ( # type: ignore[import-untyped]
StudentAimmoDashboard,
TeacherAimmoDashboard,
)
Expand Down
19 changes: 13 additions & 6 deletions backend/sso/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
Created on 01/12/2023 at 16:00:24(+00:00).
"""

import typing as t

from codeforlife.user.models import User
from django import forms
from django.contrib.auth import authenticate
from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.forms import UsernameField
from django.core.exceptions import ValidationError
from django.core.handlers.wsgi import WSGIRequest
Expand All @@ -19,9 +18,10 @@ class BaseAuthForm(forms.Form):
Base authentication form that all other authentication forms must inherit.
"""

user: User

def __init__(self, request: WSGIRequest, *args, **kwargs):
self.request = request
self.user: t.Optional[AbstractBaseUser] = None
super().__init__(*args, **kwargs)

def clean(self):
Expand All @@ -42,16 +42,23 @@ def clean(self):
code="form_errors",
)

self.user = authenticate(
user = authenticate(
self.request,
**{key: self.cleaned_data[key] for key in self.fields.keys()}
)
if self.user is None:
if user is None:
raise ValidationError(
self.get_invalid_login_error_message(),
code="invalid_login",
)
if not self.user.is_active:
if not isinstance(user, User):
raise ValidationError(
"Incorrect user class.",
code="incorrect_user_class",
)
self.user = user

if not user.is_active:
raise ValidationError(
"User is not active",
code="user_not_active",
Expand Down
6 changes: 4 additions & 2 deletions backend/sso/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,20 @@
from unittest.mock import patch

import pyotp
from codeforlife.tests import CronTestCase
from codeforlife.tests import Client, CronTestCase, TestCase
from codeforlife.user.models import AuthFactor, User
from django.core import management
from django.http import HttpResponse
from django.test import TestCase
from django.urls import reverse
from django.utils import timezone


class TestLoginView(TestCase):
"""Test the login view."""

client: Client
client_class = Client

def setUp(self):
self.user = User.objects.get(id=2)

Expand Down
19 changes: 11 additions & 8 deletions backend/sso/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,20 @@ def get_form_class(self):
raise NameError(f'Unsupported form: "{form}".')

def form_valid(self, form: BaseAuthForm): # type: ignore
if form.user is None:
raise ValueError("User should NOT be none.")
user = form.user

# Clear expired sessions.
self.request.session.clear_expired(form.user.pk)
self.request.session.clear_expired(user.pk)

# Create session (without data).
login(self.request, form.user)
user = self.request.user
login(self.request, user)

# TODO: use google analytics
user_session = {"user": form.user}
user_session: t.Dict[str, t.Any] = {"user": user}
if self.get_form_class() in [UsernameAuthForm, UserIdAuthForm]:
user_session["class_field"] = form.user.new_student.class_field
user_session[
"class_field"
] = user.new_student.class_field # type: ignore[attr-defined]
user_session["login_type"] = (
"direct" if "user_id" in self.request.POST else "classform"
)
Expand Down Expand Up @@ -108,7 +108,10 @@ def form_valid(self, form: BaseAuthForm): # type: ignore
else settings.SESSION_COOKIE_AGE
),
secure=settings.SESSION_COOKIE_SECURE,
samesite=settings.SESSION_COOKIE_SAMESITE,
samesite=t.cast(
t.Optional[t.Literal["Lax", "Strict", "None", False]],
settings.SESSION_COOKIE_SAMESITE,
),
domain=settings.SESSION_COOKIE_DOMAIN,
httponly=False,
)
Expand Down

0 comments on commit 74434a2

Please sign in to comment.