Skip to content

Commit

Permalink
chore: add pyjwt requirement
Browse files Browse the repository at this point in the history
  • Loading branch information
mumarkhan999 committed Nov 6, 2023
1 parent c5d9a5f commit e33d723
Show file tree
Hide file tree
Showing 14 changed files with 173 additions and 223 deletions.
195 changes: 58 additions & 137 deletions lti_consumer/lti_1p3/key_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,15 @@
This handles validating messages sent by the tool and generating
access token with LTI scopes.
"""
import codecs
import copy
import time
import json
import math
import time
import sys
import logging

import jwt
from Cryptodome.PublicKey import RSA
from jwkest import BadSignature, BadSyntax, WrongNumberOfParts, jwk
from jwkest.jwk import RSAKey, load_jwks_from_url
from jwkest.jws import JWS, NoSuitableSigningKeys, UnknownAlgorithm
from jwkest.jwt import JWT

from . import exceptions

Expand Down Expand Up @@ -50,14 +48,9 @@ def __init__(self, public_key=None, keyset_url=None):
# Import from public key
if public_key:
try:
new_key = RSAKey(use='sig')

# Unescape key before importing it
raw_key = codecs.decode(public_key, 'unicode_escape')

# Import Key and save to internal state
new_key.load_key(RSA.import_key(raw_key))
self.public_key = new_key
algo_obj = jwt.get_algorithm_by_name('RS256')
self.public_key = algo_obj.prepare_key(public_key)
except ValueError as err:
log.warning(
'An error was encountered while loading the LTI tool\'s key from the public key. '
Expand All @@ -76,7 +69,7 @@ def _get_keyset(self, kid=None):

if self.keyset_url:
try:
keys = load_jwks_from_url(self.keyset_url)
keys = jwt.PyJWKClient(self.keyset_url).get_jwk_set()
except Exception as err:
# Broad Exception is required here because jwkest raises
# an Exception object explicitly.
Expand All @@ -89,13 +82,13 @@ def _get_keyset(self, kid=None):
raise exceptions.NoSuitableKeys() from err
keyset.extend(keys)

if self.public_key and kid:
# Fill in key id of stored key.
# This is needed because if the JWS is signed with a
# key with a kid, pyjwkest doesn't match them with
# keys without kid (kid=None) and fails verification
self.public_key.kid = kid

if self.public_key:
if kid:
# Fill in key id of stored key.
# This is needed because if the JWS is signed with a
# key with a kid, pyjwkest doesn't match them with
# keys without kid (kid=None) and fails verification
self.public_key.kid = kid
# Add to keyset
keyset.append(self.public_key)

Expand All @@ -111,48 +104,24 @@ def validate_and_decode(self, token):
iss, sub, exp, aud and jti claims.
"""
try:
# Get KID from JWT header
jwt = JWT().unpack(token)

# Verify message signature
message = JWS().verify_compact(
token,
keys=self._get_keyset(
jwt.headers.get('kid')
)
)

# If message is valid, check expiration from JWT
if 'exp' in message and message['exp'] < time.time():
log.warning(
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
'The JWT has expired.'
)
raise exceptions.TokenSignatureExpired()

# TODO: Validate other JWT claims

# Else returns decoded message
return message

except NoSuitableSigningKeys as err:
log.warning(
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
'There is no suitable signing key.'
)
raise exceptions.NoSuitableKeys() from err
except (BadSyntax, WrongNumberOfParts) as err:
log.warning(
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
'The JWT is malformed.'
)
raise exceptions.MalformedJwtToken() from err
except BadSignature as err:
log.warning(
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
'The JWT signature is incorrect.'
)
raise exceptions.BadJwtSignature() from err
key_set = self._get_keyset()
if not key_set:
raise exceptions.NoSuitableKeys()
for i in range(len(key_set)):
try:
message = jwt.decode(
token,
key=key_set[i],
algorithms=['RS256', 'RS512',],
options={'verify_signature': True}
)
return message
except Exception:
if i == len(key_set) - 1:
raise
except Exception as token_error:
exc_info = sys.exc_info()
raise jwt.InvalidTokenError(exc_info[2]) from token_error


class PlatformKeyHandler:
Expand All @@ -171,14 +140,8 @@ def __init__(self, key_pem, kid=None):
if key_pem:
# Import JWK from RSA key
try:
self.key = RSAKey(
# Using the same key ID as client id
# This way we can easily serve multiple public
# keys on the same endpoint and keep all
# LTI 1.3 blocks working
kid=kid,
key=RSA.import_key(key_pem)
)
algo = jwt.get_algorithm_by_name('RS256')
self.key = algo.prepare_key(key_pem)
except ValueError as err:
log.warning(
'An error was encountered while loading the LTI platform\'s key. '
Expand All @@ -203,41 +166,26 @@ def encode_and_sign(self, message, expiration=None):
# Set iat and exp if expiration is set
if expiration:
_message.update({
"iat": int(round(time.time())),
"exp": int(round(time.time()) + expiration),
"iat": int(math.floor(time.time())),
"exp": int(math.floor(time.time()) + expiration),
})

# The class instance that sets up the signing operation
# An RS 256 key is required for LTI 1.3
_jws = JWS(_message, alg="RS256", cty="JWT")

try:
# Encode and sign LTI message
return _jws.sign_compact([self.key])
except NoSuitableSigningKeys as err:
log.warning(
'An error was encountered while signing the OAuth 2.0 access token JWT. '
'There is no suitable signing key.'
)
raise exceptions.NoSuitableKeys() from err
except UnknownAlgorithm as err:
log.warning(
'An error was encountered while signing the OAuth 2.0 access token JWT. '
'There algorithm is unknown.'
)
raise exceptions.MalformedJwtToken() from err
return jwt.encode(_message, self.key, algorithm="RS256")

def get_public_jwk(self):
"""
Export Public JWK
"""
public_keys = jwk.KEYS()
jwk = {"keys": []}

# Only append to keyset if a key exists
if self.key:
public_keys.append(self.key)

return json.loads(public_keys.dump_jwks())
algo_obj = jwt.get_algorithm_by_name('RS256')
public_key = algo_obj.prepare_key(self.key).public_key()
jwk['keys'].append(json.loads(algo_obj.to_jwk(public_key)))
return jwk

def validate_and_decode(self, token, iss=None, aud=None):
"""
Expand All @@ -246,49 +194,22 @@ def validate_and_decode(self, token, iss=None, aud=None):
Validates a token sent by the tool using the platform's RSA Key.
Optionally validate iss and aud claims if provided.
"""
if not self.key:
raise exceptions.RsaKeyNotSet()
try:
# Verify message signature
message = JWS().verify_compact(token, keys=[self.key])

# If message is valid, check expiration from JWT
if 'exp' in message and message['exp'] < time.time():
log.warning(
'An error was encountered while verifying the OAuth 2.0 access token. '
'The JWT has expired.'
)
raise exceptions.TokenSignatureExpired()

# Validate issuer claim (if present)
log_message_base = 'An error was encountered while verifying the OAuth 2.0 access token. '
if iss:
if 'iss' not in message or message['iss'] != iss:
error_message = 'The required iss claim is missing or does not match the expected iss value. '
log_message = log_message_base + error_message

log.warning(log_message)
raise exceptions.InvalidClaimValue(error_message)

# Validate audience claim (if present)
if aud:
if 'aud' not in message or aud not in message['aud']:
error_message = 'The required aud claim is missing.'
log_message = log_message_base + error_message

log.warning(log_message)
raise exceptions.InvalidClaimValue(error_message)

# Else return token contents
message = jwt.decode(
token,
key=self.key.public_key(),
audience=aud,
issuer=iss,
algorithms=['RS256', 'RS512'],
options={
'verify_signature': True,
'verify_aud': True if aud else False
}
)
return message

except NoSuitableSigningKeys as err:
log.warning(
'An error was encountered while verifying the OAuth 2.0 access token. '
'There is no suitable signing key.'
)
raise exceptions.NoSuitableKeys() from err
except BadSyntax as err:
log.warning(
'An error was encountered while verifying the OAuth 2.0 access token. '
'The JWT is malformed.'
)
raise exceptions.MalformedJwtToken() from err
except Exception as token_error:
exc_info = sys.exc_info()
raise jwt.InvalidTokenError(exc_info[2]) from token_error
46 changes: 32 additions & 14 deletions lti_consumer/lti_1p3/tests/test_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
Unit tests for LTI 1.3 consumer implementation
"""

import json
from unittest.mock import patch
from urllib.parse import parse_qs, urlparse
import uuid

import ddt
import jwt
import sys
from Cryptodome.PublicKey import RSA
from django.conf import settings
from django.test.testcases import TestCase
from edx_django_utils.cache import get_cache_key, TieredCache
from jwkest.jwk import load_jwks
from jwkest.jws import JWS
from jwt.api_jwk import PyJWKSet

from lti_consumer.data import Lti1p3LaunchData
from lti_consumer.lti_1p3 import exceptions
Expand All @@ -36,7 +36,9 @@
STATE = "ABCD"
# Consider storing a fixed key
RSA_KEY_ID = "1"
RSA_KEY = RSA.generate(2048).export_key('PEM')
RSA_KEY = RSA.generate(2048)
RSA_PRIVATE_KEY = RSA_KEY.export_key('PEM')
RSA_PUBLIC_KEY = RSA_KEY.public_key().export_key('PEM')


def _generate_token_request_data(token, scope):
Expand Down Expand Up @@ -69,11 +71,11 @@ def setUp(self):
lti_launch_url=LAUNCH_URL,
client_id=CLIENT_ID,
deployment_id=DEPLOYMENT_ID,
rsa_key=RSA_KEY,
rsa_key=RSA_PRIVATE_KEY,
rsa_key_id=RSA_KEY_ID,
redirect_uris=REDIRECT_URIS,
# Use the same key for testing purposes
tool_key=RSA_KEY
tool_key=RSA_PUBLIC_KEY
)

def _setup_lti_launch_data(self):
Expand Down Expand Up @@ -118,9 +120,25 @@ def _decode_token(self, token):
This also tests the public keyset function.
"""
public_keyset = self.lti_consumer.get_public_keyset()
key_set = load_jwks(json.dumps(public_keyset))

return JWS().verify_compact(token, keys=key_set)
keyset = PyJWKSet.from_dict(public_keyset).keys

for i in range(len(keyset)):
try:
message = jwt.decode(
token,
key=keyset[i].key,
algorithms=['RS256', 'RS512'],
options={
'verify_signature': True,
'verify_aud': False
}
)
return message
except Exception as token_error:
if i < len(keyset) - 1:
continue
exc_info = sys.exc_info()
raise jwt.InvalidTokenError(exc_info[2]) from token_error

@ddt.data(
({"client_id": CLIENT_ID, "redirect_uri": LAUNCH_URL, "nonce": STATE, "state": STATE}, True),
Expand Down Expand Up @@ -558,7 +576,7 @@ def test_access_token_invalid_jwt(self):
"""
request_data = _generate_token_request_data("invalid_jwt", "")

with self.assertRaises(exceptions.MalformedJwtToken):
with self.assertRaises(jwt.exceptions.InvalidTokenError):
self.lti_consumer.access_token(request_data)

def test_access_token_no_acs(self):
Expand Down Expand Up @@ -686,11 +704,11 @@ def setUp(self):
lti_launch_url=LAUNCH_URL,
client_id=CLIENT_ID,
deployment_id=DEPLOYMENT_ID,
rsa_key=RSA_KEY,
rsa_key=RSA_PRIVATE_KEY,
rsa_key_id=RSA_KEY_ID,
redirect_uris=REDIRECT_URIS,
# Use the same key for testing purposes
tool_key=RSA_KEY
tool_key=RSA_PUBLIC_KEY
)

self.preflight_response = {}
Expand Down Expand Up @@ -930,11 +948,11 @@ def setUp(self):
lti_launch_url=LAUNCH_URL,
client_id=CLIENT_ID,
deployment_id=DEPLOYMENT_ID,
rsa_key=RSA_KEY,
rsa_key=RSA_PRIVATE_KEY,
rsa_key_id=RSA_KEY_ID,
redirect_uris=REDIRECT_URIS,
# Use the same key for testing purposes
tool_key=RSA_KEY
tool_key=RSA_PUBLIC_KEY
)

self.preflight_response = {}
Expand Down
Loading

0 comments on commit e33d723

Please sign in to comment.