Skip to content

Commit

Permalink
Use one security policy class per LTIUser location
Browse files Browse the repository at this point in the history
While this is more verbose than the previous approach with a function to
parametrize the base class behaviour it has a couple of adventages:

- We can use one test to check the right policy for all endpoints
- It makes switching to a more declarative approach easier in the
  future.
  • Loading branch information
marcospri committed May 9, 2024
1 parent 9645139 commit 3905b57
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 139 deletions.
87 changes: 48 additions & 39 deletions lms/security.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import base64
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache, partial
from typing import Callable
from functools import lru_cache

import sentry_sdk
from pyramid.authentication import AuthTktCookieHelper
Expand Down Expand Up @@ -117,7 +116,7 @@ def get_policy(request: Request):

if path in {"/lti_launches", "/content_item_selection"}:
# Actual LTI backed authentication
return LTIUserSecurityPolicy(get_lti_user_from_launch_params)
return LaunchParamsLTIUserPolicy()

if path in {
"/canvas_oauth_callback",
Expand All @@ -126,40 +125,32 @@ def get_policy(request: Request):
"/api/d2l/oauth/callback",
}:
# LTIUser serialized in the state param for the oauth flow
return LTIUserSecurityPolicy(get_lti_user_from_oauth_callback)
return OAuthCallbackLTIUserPolicy()

# LTUser serialized as query param for authorization failures
if (path.startswith("/api") and path.endswith("authorize")) or path in {
# To fetch pages content from LMSes' APIs
"/api/canvas/pages/proxy",
"/api/moodle/pages/proxy",
}:
return LTIUserSecurityPolicy(
partial(get_lti_user_from_bearer_token, location="querystring")
)
# LTUser serialized as query param for authorization failures
return QueryStringBearerTokenLTIUserPolicy()

if path.startswith("/api") or path in {
"/lti/1.3/deep_linking/form_fields",
"/lti/1.1/deep_linking/form_fields",
"/lti/reconfigure",
}:
# LTUser serialized in the headers for API calls from the frontend
return LTIUserSecurityPolicy(
partial(get_lti_user_from_bearer_token, location="headers")
)
return HeadersBearerTokenLTIUserPolicy()

if path in {"/assignment", "/assignment/edit"} or path.startswith(
"/dashboard/launch/assignment/"
):
# LTUser serialized in a from for non deep-linked assignment configuration
return LTIUserSecurityPolicy(
partial(get_lti_user_from_bearer_token, location="form")
)
return FormBearerTokenLTIUserPolicy()

if path.startswith("/dashboard/organization/"):
return LTIUserSecurityPolicy(
partial(get_lti_user_from_bearer_token, location="cookies")
)
return CookiesBearerTokenLTIUserPolicy()

if path in {"/email/preferences", "/email/unsubscribe"}:
return EmailPreferencesSecurityPolicy(
Expand All @@ -175,8 +166,8 @@ def get_policy(request: Request):
class LTIUserSecurityPolicy:
"""Security policy based on the information of an LTIUser."""

def __init__(self, get_lti_user_: Callable[[Request], LTIUser]):
self._get_lti_user = get_lti_user_
def get_lti_user(self, request): # pragma: no cover
raise NotImplementedError()

@staticmethod
def _get_userid(lti_user):
Expand All @@ -200,7 +191,7 @@ def authenticated_userid(self, request):

def identity(self, request) -> Identity | None:
try:
lti_user = self._get_lti_user(request)
lti_user = self.get_lti_user(request)
except Exception: # pylint:disable=broad-exception-caught
# If anything went wrong, no identity
return None
Expand Down Expand Up @@ -228,7 +219,7 @@ def permits(self, request, _context, permission):
try:
# Getting lti_use here again for the potential exception
# side effect and allow us to return DeniedWithException accordingly
self._get_lti_user(request)
self.get_lti_user(request)
except Exception as err: # pylint:disable=broad-exception-caught
return DeniedWithException(err)

Expand All @@ -241,6 +232,42 @@ def forget(self, request): # pragma: no cover
pass


class LaunchParamsLTIUserPolicy(LTIUserSecurityPolicy):
def get_lti_user(self, request) -> LTIUser:
if "id_token" in request.params:
return LTI13AuthSchema(request).lti_user()

return LTI11AuthSchema(request).lti_user()


class OAuthCallbackLTIUserPolicy(LTIUserSecurityPolicy):
def get_lti_user(self, request) -> LTIUser:
return OAuthCallbackSchema(request).lti_user()


class BearerTokenLTIUserPolicy(LTIUserSecurityPolicy):
location: str

def get_lti_user(self, request) -> LTIUser:
return BearerTokenSchema(request).lti_user(location=self.location)


class FormBearerTokenLTIUserPolicy(BearerTokenLTIUserPolicy):
location = "form"


class CookiesBearerTokenLTIUserPolicy(BearerTokenLTIUserPolicy):
location = "cookies"


class HeadersBearerTokenLTIUserPolicy(BearerTokenLTIUserPolicy):
location = "headers"


class QueryStringBearerTokenLTIUserPolicy(BearerTokenLTIUserPolicy):
location = "querystring"


class LMSGoogleSecurityPolicy(GoogleSecurityPolicy):
def identity(self, request) -> Identity | None:
userid = self.authenticated_userid(request)
Expand Down Expand Up @@ -319,24 +346,6 @@ def _permits(identity, permission):
return Denied("denied")


@lru_cache(maxsize=1)
def get_lti_user_from_launch_params(request) -> LTIUser:
if "id_token" in request.params:
return LTI13AuthSchema(request).lti_user()

return LTI11AuthSchema(request).lti_user()


@lru_cache(maxsize=1)
def get_lti_user_from_bearer_token(request, location) -> LTIUser:
return BearerTokenSchema(request).lti_user(location=location)


@lru_cache(maxsize=1)
def get_lti_user_from_oauth_callback(request) -> LTIUser:
return OAuthCallbackSchema(request).lti_user()


@lru_cache(maxsize=1)
def get_lti_user(request) -> LTIUser | None:
"""
Expand Down
Loading

0 comments on commit 3905b57

Please sign in to comment.