Skip to content

Commit

Permalink
Refactor and ported some necessary python-jose jwk code.
Browse files Browse the repository at this point in the history
  • Loading branch information
ianliuwk1019 committed Jul 22, 2024
1 parent ea3ed2a commit e2b5c9c
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 32 deletions.
18 changes: 18 additions & 0 deletions server/backend/api/app/integration/bcsc/bcsc_constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,24 @@
import hashlib


class JWEError(Exception):
"""Base error for all JWE errors"""
pass


class JWEParseError(JWEError):
"""Could not parse the JWE string provided"""
pass


class JOSEError(Exception):
pass


class JWKError(JOSEError):
pass


class Algorithms:
# DS Algorithms
NONE = "none"
Expand Down
25 changes: 4 additions & 21 deletions server/backend/api/app/integration/bcsc/bcsc_decryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from struct import pack

from api.app.integration.bcsc import bcsc_jwk
from api.app.integration.bcsc.bcsc_constants import ALGORITHMS
from api.app.integration.bcsc.bcsc_constants import (ALGORITHMS, JWEError,
JWEParseError)
from api.app.utils import utils

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -119,6 +120,7 @@ def _decrypt_and_auth(cek_bytes, enc, cipher_text, iv, aad, auth_tag):
if enc in ALGORITHMS.HMAC_AUTH_TAG:
encryption_key, mac_key, key_len = _get_encryption_key_mac_key_and_key_length_from_cek(cek_bytes, enc)
auth_tag_check = _auth_tag(cipher_text, iv, aad, mac_key, key_len)

# BCSC enc uses algorithm in ALGORITHMS.HMAC_AUTH_TAG, below will not run.
elif enc in ALGORITHMS.GCM:
encryption_key = bcsc_jwk.jwk_construct(cek_bytes, enc)
Expand Down Expand Up @@ -174,7 +176,7 @@ def _jwe_compact_deserialize(jwe_bytes):
# Vector, the JWE Ciphertext, the JWE Authentication Tag, and the
# JWE AAD, following the restriction that no line breaks,
# whitespace, or other additional characters have been used.
jwe_bytes = ensure_binary(jwe_bytes)
jwe_bytes = utils.ensure_binary(jwe_bytes)
try:
header_segment, encrypted_key_segment, iv_segment, cipher_text_segment, auth_tag_segment = jwe_bytes.split(
b".", 4
Expand Down Expand Up @@ -257,22 +259,3 @@ def _auth_tag(ciphertext, iv, aad, mac_key, tag_length):
auth_tag = signature[0:tag_length]
return auth_tag


def ensure_binary(s):
"""Coerce **s** to bytes."""

if isinstance(s, bytes):
return s
if isinstance(s, str):
return s.encode("utf-8", "strict")
raise TypeError(f"not expecting type '{type(s)}'")


class JWEError(Exception):
"""Base error for all JWE errors"""
pass


class JWEParseError(JWEError):
"""Could not parse the JWE string provided"""
pass
148 changes: 137 additions & 11 deletions server/backend/api/app/integration/bcsc/bcsc_jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,17 @@
import hashlib
import hmac

from api.app.integration.bcsc.bcsc_constants import ALGORITHMS
from api.app.integration.bcsc.bcsc_constants import (ALGORITHMS, JWEError,
JWKError)
from api.app.utils import utils
from api.app.utils.utils import base64url_decode
from jose import jwk
from cryptography.exceptions import InvalidTag
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import (Cipher, aead, algorithms,
modes)
from cryptography.hazmat.primitives.padding import PKCS7

# from jose import jwk

# This code partial is from "python-jose" not maintained library.
# https://pypi.org/project/python-jose/
Expand All @@ -23,20 +31,13 @@ def jwk_construct(key_data, algorithm=None):
if not algorithm:
raise JWKError("Unable to find an algorithm for key: %s" % key_data)

key_class = jwk.get_key(algorithm)
# key_class = jwk.get_key(algorithm)
key_class = get_key(algorithm)
if not key_class:
raise JWKError("Unable to find an algorithm for key: %s" % key_data)
return key_class(key_data, algorithm)


class JOSEError(Exception):
pass


class JWKError(JOSEError):
pass


# BCSC uses HMAC and AES for now. Comment out others to
# be easier to be ported.
def get_key(algorithm):
Expand Down Expand Up @@ -202,6 +203,131 @@ def to_dict(self):
}


class CryptographyAESKey(Key):
KEY_128 = (ALGORITHMS.A128GCM, ALGORITHMS.A128GCMKW, ALGORITHMS.A128KW, ALGORITHMS.A128CBC)
KEY_192 = (ALGORITHMS.A192GCM, ALGORITHMS.A192GCMKW, ALGORITHMS.A192KW, ALGORITHMS.A192CBC)
KEY_256 = (
ALGORITHMS.A256GCM,
ALGORITHMS.A256GCMKW,
ALGORITHMS.A256KW,
ALGORITHMS.A128CBC_HS256,
ALGORITHMS.A256CBC,
)
KEY_384 = (ALGORITHMS.A192CBC_HS384,)
KEY_512 = (ALGORITHMS.A256CBC_HS512,)

AES_KW_ALGS = (ALGORITHMS.A128KW, ALGORITHMS.A192KW, ALGORITHMS.A256KW)

MODES = {
ALGORITHMS.A128GCM: modes.GCM,
ALGORITHMS.A192GCM: modes.GCM,
ALGORITHMS.A256GCM: modes.GCM,
ALGORITHMS.A128CBC_HS256: modes.CBC,
ALGORITHMS.A192CBC_HS384: modes.CBC,
ALGORITHMS.A256CBC_HS512: modes.CBC,
ALGORITHMS.A128CBC: modes.CBC,
ALGORITHMS.A192CBC: modes.CBC,
ALGORITHMS.A256CBC: modes.CBC,
ALGORITHMS.A128GCMKW: modes.GCM,
ALGORITHMS.A192GCMKW: modes.GCM,
ALGORITHMS.A256GCMKW: modes.GCM,
ALGORITHMS.A128KW: None,
ALGORITHMS.A192KW: None,
ALGORITHMS.A256KW: None,
}

def __init__(self, key, algorithm):
if algorithm not in ALGORITHMS.AES:
raise JWKError("%s is not a valid AES algorithm" % algorithm)
if algorithm not in ALGORITHMS.SUPPORTED.union(ALGORITHMS.AES_PSEUDO):
raise JWKError("%s is not a supported algorithm" % algorithm)

self._algorithm = algorithm
self._mode = self.MODES.get(self._algorithm)

if algorithm in self.KEY_128 and len(key) != 16:
raise JWKError(f"Key must be 128 bit for alg {algorithm}")
elif algorithm in self.KEY_192 and len(key) != 24:
raise JWKError(f"Key must be 192 bit for alg {algorithm}")
elif algorithm in self.KEY_256 and len(key) != 32:
raise JWKError(f"Key must be 256 bit for alg {algorithm}")
elif algorithm in self.KEY_384 and len(key) != 48:
raise JWKError(f"Key must be 384 bit for alg {algorithm}")
elif algorithm in self.KEY_512 and len(key) != 64:
raise JWKError(f"Key must be 512 bit for alg {algorithm}")

self._key = key

def to_dict(self):
data = {"alg": self._algorithm, "kty": "oct", "k": base64url_encode(self._key)}
return data

# Commented out, no encryption is needed for FAM-BCSC.
# def encrypt(self, plain_text, aad=None):
# plain_text = utils.ensure_binary(plain_text)
# try:
# iv = get_random_bytes(algorithms.AES.block_size // 8)
# mode = self._mode(iv)
# if mode.name == "GCM":
# cipher = aead.AESGCM(self._key)
# cipher_text_and_tag = cipher.encrypt(iv, plain_text, aad)
# cipher_text = cipher_text_and_tag[: len(cipher_text_and_tag) - 16]
# auth_tag = cipher_text_and_tag[-16:]
# else:
# cipher = Cipher(algorithms.AES(self._key), mode, backend=default_backend())
# encryptor = cipher.encryptor()
# padder = PKCS7(algorithms.AES.block_size).padder()
# padded_data = padder.update(plain_text)
# padded_data += padder.finalize()
# cipher_text = encryptor.update(padded_data) + encryptor.finalize()
# auth_tag = None
# return iv, cipher_text, auth_tag
# except Exception as e:
# raise JWEError(e)

def decrypt(self, cipher_text, iv=None, aad=None, tag=None):
cipher_text = utils.ensure_binary(cipher_text)
try:
iv = utils.ensure_binary(iv)
mode = self._mode(iv)
if mode.name == "GCM":
if tag is None:
raise ValueError("tag cannot be None")
cipher = aead.AESGCM(self._key)
cipher_text_and_tag = cipher_text + tag
try:
plain_text = cipher.decrypt(iv, cipher_text_and_tag, aad)
except InvalidTag:
raise JWEError("Invalid JWE Auth Tag")
else:
cipher = Cipher(algorithms.AES(self._key), mode, backend=default_backend())
decryptor = cipher.decryptor()
padded_plain_text = decryptor.update(cipher_text)
padded_plain_text += decryptor.finalize()
unpadder = PKCS7(algorithms.AES.block_size).unpadder()
plain_text = unpadder.update(padded_plain_text)
plain_text += unpadder.finalize()

return plain_text
except Exception as e:
raise JWEError(e)

# Commented out, no encryption is needed for FAM-BCSC.
# def wrap_key(self, key_data):
# key_data = utils.ensure_binary(key_data)
# cipher_text = aes_key_wrap(self._key, key_data, default_backend())
# return cipher_text # IV, cipher text, auth tag

# Commented out, no encryption is needed for FAM-BCSC.
# def unwrap_key(self, wrapped_key):
# wrapped_key = utils.ensure_binary(wrapped_key)
# try:
# plain_text = aes_key_unwrap(self._key, wrapped_key, default_backend())
# except InvalidUnwrap as cause:
# raise JWEError(cause)
# return plain_text


def base64url_encode(input):
"""Helper method to base64url_encode a string.
Expand Down
10 changes: 10 additions & 0 deletions server/backend/api/app/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,13 @@ def base64url_decode(input):
input += b"=" * (4 - rem)

return base64.urlsafe_b64decode(input)


def ensure_binary(s):
"""Coerce **s** to bytes."""

if isinstance(s, bytes):
return s
if isinstance(s, str):
return s.encode("utf-8", "strict")
raise TypeError(f"not expecting type '{type(s)}'")

0 comments on commit e2b5c9c

Please sign in to comment.