Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2fa flow #9

Merged
merged 11 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ verify_ssl = true
name = "pypi"

[packages]
codeforlife = {ref = "v0.8.0", git = "https://github.com/ocadotechnology/codeforlife-package-python.git"}
codeforlife = {ref = "v0.8.3", git = "https://github.com/ocadotechnology/codeforlife-package-python.git"}
django = "==3.2.20"
djangorestframework = "==3.13.1"
django-cors-headers = "==4.1.0"
Expand Down
10 changes: 5 additions & 5 deletions backend/Pipfile.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions backend/api/permissions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from codeforlife.user.models import User
from rest_framework.permissions import BasePermission
from rest_framework.request import Request
from rest_framework.views import View


class UserHasSessionAuthFactors(BasePermission):
def has_permission(self, request: Request, view: View):
return (
isinstance(request.user, User)
and request.user.session.session_auth_factors.exists()
)
16 changes: 12 additions & 4 deletions backend/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from codeforlife.tests import CronTestCase
from codeforlife.user.models import AuthFactor, User
from django.core import management
from django.http import HttpResponse
from django.test import TestCase
from django.urls import reverse
from django.utils import timezone
Expand All @@ -13,6 +14,15 @@ class TestLoginView(TestCase):
def setUp(self):
self.user = User.objects.get(id=2)

def _get_session_auth_factors(self, response: HttpResponse):
return [
auth_factor
for auth_factor in response.cookies[
"sessionid_httponly_false"
].value.split(",")
if auth_factor != ""
]

def test_post__otp(self):
AuthFactor.objects.create(
user=self.user,
Expand All @@ -28,9 +38,7 @@ def test_post__otp(self):
)

assert response.status_code == 200
self.assertDictEqual(
response.json(), {"auth_factors": [AuthFactor.Type.OTP]}
)
assert self._get_session_auth_factors(response) == [AuthFactor.Type.OTP]

self.user.userprofile.otp_secret = pyotp.random_base32()
self.user.userprofile.save()
Expand All @@ -45,7 +53,7 @@ def test_post__otp(self):
)

assert response.status_code == 200
self.assertDictEqual(response.json(), {"auth_factors": []})
assert self._get_session_auth_factors(response) == []


class TestClearExpiredView(CronTestCase):
Expand Down
22 changes: 17 additions & 5 deletions backend/api/urls.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
from django.urls import include, path, re_path

from .views import ClearExpiredView, LoginView
from .views import ClearExpiredView, LoginOptionsView, LoginView

urlpatterns = [
path(
"session/",
include(
[
re_path(
r"^login/(?P<form>email|username|user-id|otp|otp-bypass-token)/$",
LoginView.as_view(),
name="login",
path(
"login/",
include(
[
path(
"options/",
LoginOptionsView.as_view(),
name="login-options",
),
re_path(
r"^(?P<form>email|username|user-id|otp|otp-bypass-token)/$",
LoginView.as_view(),
name="login",
),
]
),
),
path(
"clear-expired/",
Expand Down
53 changes: 44 additions & 9 deletions backend/api/views.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import logging

from codeforlife.mixins import CronMixin
from codeforlife.request import HttpRequest
from codeforlife.request import HttpRequest, Request
from codeforlife.user.models import AuthFactor, User
from common.models import UserSession
from django.conf import settings
from django.contrib.auth import login
from django.contrib.auth.views import LoginView as _LoginView
from django.contrib.sessions.models import Session, SessionManager
from django.core import management
from django.http import JsonResponse
from django.http import HttpResponse, JsonResponse
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView
Expand All @@ -20,6 +22,7 @@
UserIdAuthForm,
UsernameAuthForm,
)
from .permissions import UserHasSessionAuthFactors


# TODO: add 2FA logic
Expand Down Expand Up @@ -58,20 +61,52 @@ def form_valid(self, form: BaseAuthForm):
# Save session (with data).
self.request.session.save()

return JsonResponse(
{
"auth_factors": list(
self.request.user.session.session_auth_factors.values_list(
"auth_factor__type", flat=True
)
response = HttpResponse()

# Create a non-HTTP-only session cookie with the pending auth factors.
response.set_cookie(
key="sessionid_httponly_false",
value=",".join(
self.request.user.session.session_auth_factors.values_list(
"auth_factor__type", flat=True
)
}
),
max_age=(
None
if settings.SESSION_EXPIRE_AT_BROWSER_CLOSE
else settings.SESSION_COOKIE_AGE
),
secure=settings.SESSION_COOKIE_SECURE,
samesite=settings.SESSION_COOKIE_SAMESITE,
domain=settings.SESSION_COOKIE_DOMAIN,
httponly=False,
)

return response

def form_invalid(self, form: BaseAuthForm):
return JsonResponse(form.errors, status=status.HTTP_400_BAD_REQUEST)


class LoginOptionsView(APIView):
http_method_names = ["get"]
permission_classes = [UserHasSessionAuthFactors]

def get(self, request: Request):
user: User = request.user
session_auth_factors = user.session.session_auth_factors

response_data = {"id": user.id}
if session_auth_factors.filter(
auth_factor__type=AuthFactor.Type.OTP
).exists():
response_data[
"otp_bypass_token_exists"
] = user.otp_bypass_tokens.exists()

return Response(response_data)


class ClearExpiredView(CronMixin, APIView):
def get(self, request):
# objects is missing type SessionManager
Expand Down