diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5397e66 --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +.DS_Store +/venv +*.py[cdo] +*.swp +.idea/ +MANIFEST +build/ +dist/ +docs/_build +.cache +.tox/ +.eggs/ +*.egg-info/ diff --git a/HISTORY.rst b/HISTORY.rst new file mode 100644 index 0000000..d3ee401 --- /dev/null +++ b/HISTORY.rst @@ -0,0 +1,10 @@ +.. :changelog: + +Release History +--------------- + +1.0.0 (2017-01-25) +++++++++++++++++++ + +- Initial public release + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..dc81130 --- /dev/null +++ b/LICENSE @@ -0,0 +1,13 @@ +Copyright 2017 AOL Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..775988a --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include README.md HISTORY.rst requirements.txt \ No newline at end of file diff --git a/README.rst b/README.rst new file mode 100644 index 0000000..64403c7 --- /dev/null +++ b/README.rst @@ -0,0 +1,178 @@ +mrcrypt: Multi-Region Encryption +================================ + +mrcrypt is a command-line tool that allows you to encrypt secrets in +multiple AWS regions using KMS keys using a technique called `Envelope +Encryption `__. +It is intended to be used with the `AWS Encryption SDK for +Java `__, but could +be used on its own. 
+ +Compatibility with the AWS Encryption SDK +''''''''''''''''''''''''''''''''''''''''' + +**All files encrypted with mrcrypt can be decrypted with the AWS +Encryption SDK.** But not all files encrypted with the AWS Encryption +SDK can be decrypted by mrcrypt. + +Currently, mrcrypt only supports the AWS Encryption SDK's default (and +most secure) cryptographic algorithm: + +- Content Type: Framed +- Frame size: 4096 +- Algorithm: ALG\_AES\_256\_GCM\_IV12\_TAG16\_HKDF\_SHA384\_ECDSA\_P384 + +Support for the remaining algorithms is planned, but files encrypted +with the AWS Encryption SDK using one of the other algorithms are +currently not supported in mrcrypt. + +Also, the AWS Encryption SDK creates files using elliptic curve point +compression. Files created with mrcrypt do not use point compression +because it is not currently supported in +`Cryptography `__, a Python +package mrcrypt uses. The uncompressed points are just as secure as the +compressed points, but files are a few bytes larger. The AWS Encryption +SDK can decrypt files that use uncompressed points, meaning all files +created with mrcrypt are compatible with the AWS Encryption SDK. + +Installation +------------ + +To install mrcrypt simply clone the repo, and run ``pip install .`` +inside of the directory: + +:: + + git clone ssh://git@stash.ops.aol.com:2022/identity_services/mrcrypt.git + cd mrcrypt + pip install . + +**Note:** mrcrypt uses the Python package +`Cryptography `__ which depends on +``libffi``. You may need to install it on your system if +``pip install .`` fails. For more specific instructions for your OS: +https://cryptography.io/en/latest/installation/ + +Usage +----- + +:: + + usage: mrcrypt [-h] [-p PROFILE] [-e ENCRYPTION_CONTEXT] [-d] [-o OUTFILE] + {encrypt,decrypt} ... + + Multi Region Encryption. A tool for managing secrets across multiple AWS + regions. 
+ + positional arguments: + {encrypt,decrypt} + + optional arguments: + -h, --help show this help message and exit + -p PROFILE, --profile PROFILE + The profile to use + -e ENCRYPTION_CONTEXT, --encryption_context ENCRYPTION_CONTEXT + An encryption context to use. (Cannot have whitespace) + -d, --debug Enable more output for debugging + -o OUTFILE, --outfile OUTFILE + The file to write the results to + +Both the encrypt, and decrypt commands can encrypt and decrypt files in +directories recursively. + +Named Profiles +'''''''''''''' + +If you have multiple named profiles in your ``~/.aws/credentials`` file, +you can specify one using the ``-p`` argument. + +:: + + mrcrypt -p my_profile encrypt alias/master-key secrets.txt + +Encryption Context +'''''''''''''''''' + +You can specify an `encryption +context `__ +using the ``-e`` argument. This flag takes a JSON object with no spaces: + +:: + + # encrypt + mrcrypt -e '{"key":"value","key2":"value2"}' encrypt alias/master-key secrets.txt + + # decrypt + mrcrypt -e '{"key":"value","key2":"value2"}' decrypt secrets.txt.encrypted + +Output file name +'''''''''''''''' + +If you want to specify the output filename, you can use the ``-o`` +argument. + +``# Encrypt 'file.txt' writing the output into 'encrypted-file.txt' mrcrypt -o encrypted-file.txt encrypt alias/master-key file.txt`` + +By default, when encrypting, mrcrypt will create a file with the same +file name as the input file with ``.encrypted`` appended to the end. +When decrypting, if the file ends with ``.encrypted`` it will write the +plaintext output to a file of the same name but without the +``.encrypted``. + +Encryption +---------- + +:: + + usage: mrcrypt encrypt [-h] [-r REGIONS [REGIONS ...]] [-e ENCRYPTION_CONTEXT] + key_id filename + + Encrypts a file or directory recursively + + positional arguments: + key_id An identifier for a customer master key. + filename The file or directory to encrypt. 
Use a - to read from + stdin + + optional arguments: + -h, --help show this help message and exit + -r REGIONS [REGIONS ...], --regions REGIONS [REGIONS ...] + A list of regions to encrypt with KMS. End the list + with -- + -e ENCRYPTION_CONTEXT, --encryption_context ENCRYPTION_CONTEXT + An encryption context to use + +**Example:** Encrypt ``secrets.txt`` with the key alias +``alias/master-key`` in the regions ``us-east-1`` and ``us-west-2``: + +``mrcrypt encrypt -r us-east-1 us-west-2 -- alias/master-key secrets.txt`` + +Decryption +---------- + +:: + + usage: mrcrypt decrypt [-h] filename + + Decrypts a file + + positional arguments: + filename The file or directory to decrypt. Use a - to read from stdin + + optional arguments: + -h, --help show this help message and exit + +**Example:** To decrypt ``secrets.txt.encrypted``: + +:: + + mrcrypt decrypt secrets.txt.encrypted + +**Note:** Be careful when decrypting a directory. If the directory +contains files that are not encrypted, it will fail. + +Testing +''''''' + +Running tests for mrcrypt is easy if you have ``tox`` installed. Simply +run ``tox`` at the project's root. diff --git a/mrcrypt/__init__.py b/mrcrypt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mrcrypt/algorithms.py b/mrcrypt/algorithms.py new file mode 100644 index 0000000..2fbf0f5 --- /dev/null +++ b/mrcrypt/algorithms.py @@ -0,0 +1,152 @@ +""" +mrcrypt.algorithms +~~~~~~~~~~~~~~~~~~ + +Contains named tuples that describe different algorithms supported by this tool. +""" +from collections import namedtuple + +#: Max unsigned 16 bit number +GCM_MAX_CONTENT_LENGTH_BITS = (1 << 16) - 1 + +#: All lengths are in bytes, unless stated otherwise. 
+AlgorithmProfile = namedtuple('AlgorithmProfile', + 'block_size_bits, iv_length, tag_length, max_content_length_bits, ' + 'key_algorithm, key_length, id, data_key_algorithm, ' + 'data_key_length, trailing_signature_algorithm, ' + 'trailing_signature_length_bits') + +alg_aes_128_gcm_iv12_tag16_no_kdf = AlgorithmProfile( + block_size_bits=128, + iv_length=12, + tag_length=16, + max_content_length_bits=GCM_MAX_CONTENT_LENGTH_BITS, + key_algorithm='AES', + key_length=16, + id=0x0014, + data_key_algorithm='AES', + data_key_length=16, + trailing_signature_algorithm=None, + trailing_signature_length_bits=None) + +alg_aes_192_gcm_iv12_tag16_no_kdf = AlgorithmProfile( + block_size_bits=128, + iv_length=12, + tag_length=16, + max_content_length_bits=GCM_MAX_CONTENT_LENGTH_BITS, + key_algorithm='AES', + key_length=24, + id=0x0046, + data_key_algorithm='AES', + data_key_length=24, + trailing_signature_algorithm=None, + trailing_signature_length_bits=None) + +alg_aes_256_gcm_iv12_tag16_no_kdf = AlgorithmProfile( + block_size_bits=128, + iv_length=12, + tag_length=16, + max_content_length_bits=GCM_MAX_CONTENT_LENGTH_BITS, + key_algorithm='AES', + key_length=32, + id=0x0078, + data_key_algorithm='AES', + data_key_length=32, + trailing_signature_algorithm=None, + trailing_signature_length_bits=None) + +alg_aes_128_gcm_iv12_tag16_hkdf_sha256 = AlgorithmProfile( + block_size_bits=128, + iv_length=12, + tag_length=16, + max_content_length_bits=GCM_MAX_CONTENT_LENGTH_BITS, + key_algorithm='AES', + key_length=16, + id=0x0114, + data_key_algorithm='HkdfSHA256', + data_key_length=16, + trailing_signature_algorithm=None, + trailing_signature_length_bits=None) + +alg_aes_192_gcm_iv12_tag16_hkdf_sha256 = AlgorithmProfile( + block_size_bits=128, + iv_length=12, + tag_length=16, + max_content_length_bits=GCM_MAX_CONTENT_LENGTH_BITS, + key_algorithm='AES', + key_length=24, + id=0x0146, + data_key_algorithm='HkdfSHA256', + data_key_length=24, + trailing_signature_algorithm=None, + 
trailing_signature_length_bits=None) + +alg_aes_256_gcm_iv12_tag16_hkdf_sha256 = AlgorithmProfile( + block_size_bits=128, + iv_length=12, + tag_length=16, + max_content_length_bits=GCM_MAX_CONTENT_LENGTH_BITS, + key_algorithm='AES', + key_length=32, + id=0x0178, + data_key_algorithm='HkdfSHA256', + data_key_length=32, + trailing_signature_algorithm=None, + trailing_signature_length_bits=None) + +alg_aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256 = AlgorithmProfile( + block_size_bits=128, + iv_length=12, + tag_length=16, + max_content_length_bits=GCM_MAX_CONTENT_LENGTH_BITS, + key_algorithm='AES', + key_length=16, + id=0x0214, + data_key_algorithm='HkdfSHA256', + data_key_length=16, + trailing_signature_algorithm='SHA256withECDSA', + trailing_signature_length_bits=72) + +alg_aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384 = AlgorithmProfile( + block_size_bits=128, + iv_length=12, + tag_length=16, + max_content_length_bits=GCM_MAX_CONTENT_LENGTH_BITS, + key_algorithm='AES', + key_length=24, + id=0x0346, + data_key_algorithm='HkdfSHA384', + data_key_length=24, + trailing_signature_algorithm='SHA384withECDSA', + trailing_signature_length_bits=104) + +alg_aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384 = AlgorithmProfile( + block_size_bits=128, + iv_length=12, + tag_length=16, + max_content_length_bits=GCM_MAX_CONTENT_LENGTH_BITS, + key_algorithm='AES', + key_length=32, + id=0x0378, + data_key_algorithm='HkdfSHA384', + data_key_length=32, + trailing_signature_algorithm='SHA384withECDSA', + trailing_signature_length_bits=104) + + +def algorithm_from_id(algorithm_id): + """Retrieves an :class:`AlgorithmProfile` from ``algorithm_id``.""" + mapping = _get_mapping() + try: + return mapping[algorithm_id] + except KeyError: + raise ValueError('The number {} does not map to an algorithm'.format(algorithm_id)) + + +def _get_mapping(): + """Builds a dictionary, mapping IDs to their corresponding :class:`AlgorithmProfile`.""" + return dict((v.id, v) for v in globals().values() if 
ENCRYPTED_FILE_ENDING = ".encrypted"
DECRYPTED_FILE_ENDING = ".decrypted"


class EncryptCommand(object):
    """Represents the encrypt sub-command from the commandline.

    :param file_path: The path of the file or directory to act on.
    :param master_key_id: The CMK to use when generating a data key.
    :param outfile: (optional) The file to write to.
    :param regions: (optional) A list of regions.
    :param profile: (optional) The named profile to use when making requests to AWS.
    :param encryption_context: (optional) An encryption context to use during encryption.

    :raises ValueError: If ``outfile`` is given while ``file_path`` is a directory.
    """
    def __init__(self, file_path, master_key_id, outfile=None, regions=None, profile=None,
                 encryption_context=None):
        # A single output name cannot hold the result of many input files.
        if os.path.isdir(file_path) and outfile:
            raise ValueError("Cannot specify an outfile for a directory")

        self.file_path = file_path
        self.master_key_id = master_key_id
        self.profile = profile
        self.encryption_context = encryption_context
        self.outfile = outfile

        self.regions = [] if regions is None else regions

    def encrypt(self):
        """Handles encryption of both files and directories.

        If ``self.file_path`` is a directory, every file below it is encrypted recursively.

        :raises mrcrypt.exceptions.UnsupportedFileObject: If ``self.file_path`` is neither a
            regular file nor a directory.
        """
        if os.path.isfile(self.file_path):
            self._encrypt_file(self.file_path)
        elif os.path.isdir(self.file_path):
            for root, _subdirs, files in os.walk(self.file_path):
                for filename in files:
                    self._encrypt_file(os.path.join(root, filename))
        else:
            raise exceptions.UnsupportedFileObject("{} is not a file".format(self.file_path))

    def _encrypt_file(self, filename):
        """Encrypts the contents of ``filename`` and writes the output."""
        parent_dir = utils.get_parent_dir_path(filename)

        contents = mrcrypt.io.read_plaintext_file(filename)

        message = crypto.encrypt_string(contents,
                                        self.master_key_id,
                                        self.regions,
                                        self.profile,
                                        self.encryption_context)

        outfile = self._generate_outfile(filename)
        mrcrypt.io.write_message(outfile, parent_dir, message)

    def _generate_outfile(self, filename):
        """Returns ``self.outfile`` if set, otherwise ``filename`` with ``.encrypted``
        appended."""
        return filename + ENCRYPTED_FILE_ENDING if self.outfile is None else self.outfile


class DecryptCommand(object):
    """Represents the decrypt sub-command from the commandline.

    :param file_path: The path of the file or directory to act on.
    :param outfile: (optional) The file to write to.
    :param profile: (optional) The named profile to use when making requests to AWS.

    :raises ValueError: If ``outfile`` is given while ``file_path`` is a directory.
    """
    def __init__(self, file_path, outfile=None, profile=None):
        # Same rule as EncryptCommand: without it every file found in the directory would
        # be written to the same ``outfile``, each one overwriting the previous result.
        if os.path.isdir(file_path) and outfile:
            raise ValueError("Cannot specify an outfile for a directory")

        self.file_path = file_path
        self.outfile = outfile
        self.profile = profile

    def decrypt(self):
        """Handles decryption of both files and directories.

        If ``self.file_path`` is a directory, every file below it is decrypted recursively.

        :raises mrcrypt.exceptions.UnsupportedFileObject: If ``self.file_path`` is neither a
            regular file nor a directory.
        """
        if os.path.isfile(self.file_path):
            self._decrypt_file(self.file_path)
        elif os.path.isdir(self.file_path):
            for root, _subdirs, files in os.walk(self.file_path):
                for filename in files:
                    self._decrypt_file(os.path.join(root, filename))
        else:
            raise exceptions.UnsupportedFileObject("{} is not a file".format(self.file_path))

    def _decrypt_file(self, filename):
        """Decrypts the contents of ``filename`` and writes the output to a file that's read only
        by the owner (0400)."""
        parent_dir = utils.get_parent_dir_path(filename)

        message = mrcrypt.io.parse_message_file(filename)
        content = crypto.decrypt_message(message, profile=self.profile)

        outfile = self._generate_outfile(filename)
        mrcrypt.io.write_str(outfile, parent_dir, content, stat.S_IRUSR)

    def _generate_outfile(self, filename):
        """If ``self.outfile`` is not None, returns ``self.outfile``. Otherwise it checks for the
        ``.encrypted`` extension and removes it. If it doesn't have the ``.encrypted`` extension,
        it appends a ``.decrypted`` to ``filename`` and returns it."""
        if self.outfile is not None:
            return self.outfile
        if filename.endswith(ENCRYPTED_FILE_ENDING):
            return filename[:-len(ENCRYPTED_FILE_ENDING)]
        return filename + DECRYPTED_FILE_ENDING
def _parse_encryption_context(value):
    """``argparse`` type callable for ``--encryption_context``.

    Evaluates ``value`` as a Python literal. ``ast.literal_eval`` raises ``SyntaxError`` for
    malformed input, which argparse does not catch on its own; it is converted here so the
    user gets a normal usage error instead of a traceback.

    :raises argparse.ArgumentTypeError: If ``value`` is not a valid literal.
    """
    try:
        return ast.literal_eval(value)
    except (SyntaxError, TypeError, ValueError):
        raise argparse.ArgumentTypeError(
            'invalid encryption context: {!r}'.format(value))


def _build_encrypt_parser(subparsers):
    """Builds the encryption subparser."""
    encrypt_parser = subparsers.add_parser('encrypt',
                                           description='Encrypts a file or directory recursively')

    encrypt_parser.add_argument('-r', '--regions',
                                nargs='+',
                                help='A list of regions to encrypt with KMS. End the list with --')
    encrypt_parser.add_argument('-e', '--encryption_context', type=_parse_encryption_context,
                                action='store', help='An encryption context to use')

    encrypt_parser.add_argument('key_id',
                                help='An identifier for a customer master key.')

    encrypt_parser.add_argument('filename',
                                action='store',
                                help='The file or directory to encrypt. Use a - to read from '
                                     'stdin')


def _build_decrypt_parser(subparsers):
    """Builds the decryption subparser."""
    decrypt_parser = subparsers.add_parser('decrypt',
                                           description='Decrypts a file')

    decrypt_parser.add_argument('filename',
                                action='store',
                                help='The file or directory to decrypt. Use a - to read from '
                                     'stdin')


def parse_args(args=None):
    """Builds the parser and parses the command-line arguments.

    :param args: (optional) A list of argument strings. ``sys.argv[1:]`` is used when None.

    :return: The populated :class:`argparse.Namespace`.
    """
    parser = argparse.ArgumentParser(
        description='Multi Region Encryption. A tool for managing secrets across multiple AWS '
                    'regions.')

    parser.add_argument('-p', '--profile', action='store', help='The profile to use')
    parser.add_argument('-v', '--verbose', action='count',
                        help='More verbose output')
    parser.add_argument('-o', '--outfile', action='store', help='The file to write the results to')

    subparsers = parser.add_subparsers(dest='command')
    # Python 3 treats sub-commands as optional by default, which made a bare ``mrcrypt``
    # invocation exit silently. Python 2 already reports an error in that case.
    subparsers.required = True

    _build_encrypt_parser(subparsers)
    _build_decrypt_parser(subparsers)

    return parser.parse_args(args)


def _get_logging_level(verbosity_level):
    """Maps the CLI verbosity count to a ``logging`` level.

    ``None`` (no ``-v``) gives WARN, a single ``-v`` gives INFO, anything else gives DEBUG.
    """
    if verbosity_level is None:
        return logging.WARN
    if verbosity_level == 1:
        return logging.INFO
    return logging.DEBUG


def parse():
    """Entry point: parses ``sys.argv``, configures logging and runs the chosen sub-command."""
    args = parse_args()

    logging.basicConfig(stream=sys.stderr, level=_get_logging_level(args.verbose))

    if args.command == 'decrypt':
        decrypt_command = commands.DecryptCommand(args.filename, outfile=args.outfile,
                                                  profile=args.profile)
        decrypt_command.decrypt()

    elif args.command == 'encrypt':
        # A well-formed literal that is not a dictionary (e.g. a list) is rejected here.
        if args.encryption_context is not None and not isinstance(args.encryption_context, dict):
            print('Invalid dictionary in encryption context argument')
            sys.exit(1)

        encrypt_command = commands.EncryptCommand(args.filename, args.key_id,
                                                  outfile=args.outfile, regions=args.regions,
                                                  profile=args.profile,
                                                  encryption_context=args.encryption_context)
        encrypt_command.encrypt()
def decrypt_message(message, profile=None):
    """Decrypts the content inside of the provided :class:`mrcrypt.message.Message` object.

    The message is validated with ``validate_message_integrity`` before any content is
    decrypted.

    :param message: The message to decrypt.
    :param profile: (optional) The named profile to use when making requests to AWS.

    :return: The decrypted content.
    """
    validate_message_integrity(message, profile=profile)
    decryption_handler = DecryptionHandler(message)
    return decryption_handler.decrypt_content(profile=profile)


class DecryptionHandler(object):
    """An object that can decrypt the content contained in a :class:`mrcrypt.message.Message`
    object."""
    def __init__(self, message_):
        self._message = message_

    def decrypt_content(self, profile=None):
        """Decrypts the content contained by a message, and returns it as a byte string.

        Every frame except the last is decrypted as a regular frame; the last frame is
        decrypted as the final frame.
        """
        key = get_key_from_header(self._message.header, profile=profile)
        message_id = self._message.header.message_id
        frames = self._message.body.frames

        parts = [decrypt_framed_content(frame, key, message_id, is_final_frame=False)
                 for frame in frames[:-1]]
        parts.append(decrypt_framed_content(frames[-1], key, message_id, is_final_frame=True))

        # b'' is the same object type as '' on Python 2 and keeps the join valid for the
        # byte strings returned by the decryptor.
        return b''.join(parts)


def decrypt_framed_content(frame, key, message_id, is_final_frame):
    """Decrypts the content inside ``frame``.

    :param frame: The frame to decrypt.
    :param key: The key to decrypt with.
    :param message_id: The message ID of the Message that ``frame`` belongs to.
    :param is_final_frame: A boolean representing if ``frame`` is the final frame in the message.

    :return: The decrypted frame content.
    """
    decryptor = get_decryptor(key, frame.iv, frame.authentication_tag)

    frame_string_id = FINAL_FRAME_STRING_ID if is_final_frame else FRAME_STRING_ID

    # Additional authenticated data: message id, frame marker, 4-byte sequence number and
    # 8-byte content length, in that order.
    content_aad = (message_id +
                   frame_string_id +
                   utils.num_to_bytes(frame.sequence_number, 4) +
                   utils.num_to_bytes(frame.encrypted_content_length, 8))

    decryptor.authenticate_additional_data(content_aad)

    return decryptor.update(frame.encrypted_content) + decryptor.finalize()


def get_decryptor(key, iv, authentication_tag):
    """Get a :class:`cryptography.hazmat.primitives.ciphers.CipherContext` object to use for
    decryption, configured with ``key``, ``iv``, and ``authentication_tag`` (AES in GCM
    mode)."""
    return Cipher(
        algorithm=algorithms.AES(key),
        mode=modes.GCM(iv, authentication_tag),
        backend=default_backend()
    ).decryptor()


def get_key_from_header(header, profile=None):
    """Retrieves the content key using the encrypted data keys contained by ``header``.

    Each encrypted data key is tried in order, against KMS in the region named by the key's
    ARN. The first one KMS decrypts is run through HKDF and returned. If a region fails, the
    next key is tried.

    :param header: The message header holding the encrypted data keys.
    :param profile: (optional) The named profile to use when making requests to AWS.

    :return: The derived key.

    :raises ValueError: If ``header`` contains no encrypted data keys.
    :raises Exception: The error of the last attempt, if no data key could be decrypted.
    """
    # TODO: Accept a region to use when choosing a key from the command-line
    session = boto3.Session(profile_name=profile)
    info = utils.num_to_bytes(header.algorithm_id, 2) + header.message_id

    last_error = None

    for encrypted_key in header.encrypted_data_keys:
        region = utils.region_from_arn(encrypted_key.key_provider_info.decode())
        client = session.client('kms', region_name=region)

        try:
            data_key = client.decrypt(
                CiphertextBlob=encrypted_key.encrypted_data_key,
                EncryptionContext=header.encryption_context)
        except Exception as error:  # re-raised below once every region has been tried
            logging.warning("Could not decrypt the data key in region %s: %s", region, error)
            last_error = error
            continue

        return mrcrypt.crypto.utils.derive_hkdf_key(data_key[u'Plaintext'], info)

    if last_error is not None:
        raise last_error

    raise ValueError('The message header does not contain any encrypted data keys')
def validate_message_integrity(message, profile=None):
    """Validates the header, and then the whole message when the algorithm uses a trailing
    signature.

    **NOTE:** The body is validated during the decryption of the body's content."""
    validate_header(message.header, profile=profile)

    signature_algorithm = message.header.algorithm.trailing_signature_algorithm
    if signature_algorithm is None:
        return

    validate_message(message)


def validate_header(header, profile=None):
    """Validates the header using the header's authentication tag.

    The authenticated header fields are fed to a decryptor as additional data only; no
    ciphertext is processed, so finalizing the decryptor checks nothing but the tag.
    """
    header_key = get_key_from_header(header, profile=profile)
    tag_checker = get_decryptor(header_key, header.iv, header.authentication_tag)

    authenticated_fields = header.serialize_authenticated_fields()
    tag_checker.authenticate_additional_data(authenticated_fields)
    tag_checker.finalize()

    logging.info("Header integrity verified.")


def validate_message(message):
    """Validates the entire message using the signature found in the message footer.

    The public key is read from the ``aws-crypto-public-key`` entry of the header's
    encryption context (a base-64 encoded curve point).
    """
    encoded_key = message.header.encryption_context['aws-crypto-public-key']
    curve_point = base64.b64decode(encoded_key)

    numbers = EllipticCurvePublicNumbers.from_encoded_point(ec.SECP384R1(), curve_point)
    verifying_key = numbers.public_key(default_backend())

    signature_check = verifying_key.verifier(message.footer.signature,
                                             ec.ECDSA(hashes.SHA384()))
    signature_check.update(message.serialize_authenticated_fields())
    signature_check.verify()

    logging.info("Message integrity verified.")
def encrypt_string(s, master_key_id, regions=None, profile=None, encryption_context=None):
    """Encrypts a string with the given CMK in the regions provided.

    Builds the KMS clients with ``utils.get_kms_clients`` and delegates the work to a fresh
    :class:`EncryptionHandler`.

    :param s: The string to encrypt.
    :param master_key_id: The key id of a CMK (alias, arn, etc.).
    :param regions: (optional) A list of regions.
    :param profile: (optional) A named profile.
    :param encryption_context: (optional) A dictionary to use when encrypting.

    :return: A :class:`mrcrypt.message.Message` containing the encrypted string.
    """
    regions = [] if regions is None else regions

    kms_clients = utils.get_kms_clients(regions, profile)

    handler = EncryptionHandler(kms_clients, master_key_id)

    return handler.encrypt_string(s, encryption_context)


class EncryptionHandler(object):
    """An object that can encrypt a string into a :class:`mrcrypt.message.Message` object.

    :param kms_clients: A list of boto3 clients connected to KMS.
    :param master_key_id: The ID of a CMK.
    :param algorithm: (optional) A :class:`mrcrypt.algorithms.AlgorithmProfile` object.
    :param frame_size: (optional) The size of the content in a single frame.

    :raises NotImplementedError: If ``algorithm`` is not the default algorithm.
    """

    def __init__(self, kms_clients, master_key_id,
                 algorithm=mrcrypt.algorithms.default_algorithm(), frame_size=4096):
        self.kms_clients = kms_clients
        self.master_key_id = master_key_id
        self.algorithm = algorithm
        self.frame_size = frame_size

        # TODO: Support the other algorithms
        if self.algorithm != mrcrypt.algorithms.default_algorithm():
            raise NotImplementedError

        # ``self.signer`` and ``self.encryption_context`` are only created on this branch;
        # ``encrypt_string`` below uses both unconditionally.
        if self.algorithm.trailing_signature_algorithm is not None:
            private_key = ec.generate_private_key(ec.SECP384R1(), default_backend())

            self.signer = private_key.signer(ec.ECDSA(hashes.SHA384()))

            # The public half of the signing key travels with the message inside the
            # encryption context.
            public_curve_point = get_public_compressed_curve_point(private_key)
            self.encryption_context = {'aws-crypto-public-key': public_curve_point}

    def encrypt_string(self, s, encryption_context=None):
        """Encrypts the given string, returning a :class:`mrcrypt.message.Message` containing that
        encrypted string.

        :param s: The string to encrypt.
        :param encryption_context: (optional) A dictionary merged into the handler's own
                                   encryption context before the data key is requested.
        """
        encryption_context = {} if encryption_context is None else encryption_context

        # NOTE(review): this mutates the handler's context, and a caller-supplied
        # 'aws-crypto-public-key' entry would replace the generated one -- confirm intended.
        self.encryption_context.update(encryption_context)

        data_key = self.get_data_key(self.encryption_context)
        encrypted_data_keys = self.get_encrypted_data_keys(data_key, self.encryption_context)

        header = message.Header(
            version=1,
            type_=0x80,
            algorithm_id=self.algorithm.id,
            message_id=os.urandom(16),
            encryption_context=self.encryption_context,
            encrypted_data_keys=encrypted_data_keys,
            content_type=2,
            reserved_field=0,
            frame_content_length=self.frame_size,
            iv=os.urandom(self.algorithm.iv_length)
        )

        # HKDF context: 2-byte algorithm id followed by the message id.
        info = utils.num_to_bytes(header.algorithm_id, 2) + header.message_id

        encryption_key = mrcrypt.crypto.utils.derive_hkdf_key(data_key['Plaintext'], info)

        header_authentication_tag = sign_bytes(header.serialize_authenticated_fields(),
                                               encryption_key, header.iv)

        header.authentication_tag = header_authentication_tag

        # The header above is always built with content_type=2, so only the framed branch
        # is reachable as written.
        if header.content_type == 1:
            raise NotImplementedError
        elif header.content_type == 2:
            body = _encrypt_as_framed(s, header, encryption_key, self.frame_size)
        else:
            raise exceptions.InvalidContentTypeError(
                "Header's content type had a value of {}".format(header.content_type))

        # The footer signature covers the serialized header followed by the serialized body.
        message_bytes = str(header.serialize() + body.serialize())

        self.signer.update(message_bytes)

        # NOTE(review): the signer is finalized here, so this handler presumably cannot
        # encrypt a second string -- confirm against the cryptography signer API.
        footer = message.Footer(self.signer.finalize())

        return message.Message(header, body, footer)

    def get_data_key(self, encryption_context=None):
        """Requests a data key from KMS with the first KMS client."""
        return self.kms_clients[0].generate_data_key(
            KeyId=self.master_key_id,
            KeySpec='AES_256',
            EncryptionContext=encryption_context)

    def get_encrypted_data_keys(self, data_key, encryption_context):
        """Returns a list of encrypted data keys: the one returned with ``data_key`` by the first
        KMS client, followed by ``data_key``'s plaintext encrypted by each of the remaining
        clients in ``self.kms_clients``."""
        encrypted_data_keys = [message.header.EncryptedDataKey(b'aws-kms',
                                                               bytes(data_key['KeyId']),
                                                               bytes(data_key['CiphertextBlob']))]

        # The first client already produced its ciphertext in ``generate_data_key``.
        for client in self.kms_clients[1:]:
            key = client.encrypt(KeyId=self.master_key_id,
                                 Plaintext=data_key['Plaintext'],
                                 EncryptionContext=encryption_context)
            encrypted_data_key = message.header.EncryptedDataKey(b'aws-kms',
                                                                 bytes(key['KeyId']),
                                                                 bytes(key['CiphertextBlob']))
            encrypted_data_keys.append(encrypted_data_key)

        return encrypted_data_keys


def _encrypt_as_framed(s, header, encryption_key, frame_size=4096):
    """Encrypts the provided string into frames of ``frame_size``.

    :param s: The string to encrypt.
    :param header: A :class:`mrcrypt.message.Header` object.
    :param encryption_key: The key to encrypt the frame content with.
    :param frame_size: (optional) The size of the content inside each frame.

    :return: A :class:`mrcrypt.message.FrameBody` holding the encrypted frames.
    """
    frames = []

    # Sequence numbers start at 1.
    for i, unencrypted_content in enumerate(utils.split(s, frame_size), start=1):
        # A chunk is treated as the final frame when it is not exactly ``frame_size`` long.
        # NOTE(review): if ``utils.split`` (not shown here) yields only full-size chunks for
        # input whose length is a multiple of ``frame_size``, or nothing for empty input, no
        # frame is marked final -- verify against ``utils.split``.
        final_frame = len(unencrypted_content) != frame_size

        frame_string_id = FINAL_FRAME_STRING_ID if final_frame else FRAME_STRING_ID

        # Additional authenticated data: message id, frame marker, 4-byte sequence number
        # and 8-byte content length.
        content_aad = (header.message_id + frame_string_id + utils.num_to_bytes(i, 4) +
                       utils.num_to_bytes(len(unencrypted_content), 8))

        # Every frame gets its own random IV.
        frame_iv = os.urandom(header.algorithm.iv_length)

        encryptor = get_encryptor(encryption_key, frame_iv)

        encryptor.authenticate_additional_data(content_aad)

        ciphertext = encryptor.update(unencrypted_content) + encryptor.finalize()

        frame = message.body.Frame(
            is_final_frame=final_frame,
            sequence_number=i,
            iv=frame_iv,
            encrypted_content=ciphertext,
            authentication_tag=encryptor.tag
        )

        frames.append(frame)

    return message.FrameBody(header, frames)


def get_public_compressed_curve_point(private_key):
    """Returns the base-64 encoded public curve point of an
    :class:`cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey` object.

    NOTE(review): despite the function name, the project README states that mrcrypt writes
    uncompressed points -- confirm what ``encode_point()`` produces."""
    encoded_point = private_key.public_key().public_numbers().encode_point()
    return base64.b64encode(encoded_point)


def sign_bytes(bytes_, encryption_key, iv):
    """Generates an authentication tag which can authenticate ``bytes_``.

    ``bytes_`` is passed to the cipher as additional authenticated data only; nothing is
    encrypted.

    :param bytes_: The byte string to sign.
    :param encryption_key: The encryption key to use.
    :param iv: The initialization vector.

    :return: The authentication tag.
    """
    encryptor = get_encryptor(encryption_key, iv)

    encryptor.authenticate_additional_data(bytes_)

    # The tag is read only after the encryptor has been finalized.
    encryptor.finalize()

    return encryptor.tag


def get_encryptor(encryption_key, iv):
    """Get a :class:`cryptography.hazmat.primitives.ciphers.CipherContext` object to use for
    encryption, configured with ``encryption_key`` and an ``iv`` (AES in GCM mode)."""
    return Cipher(
        algorithm=algorithms.AES(encryption_key),
        mode=modes.GCM(iv),
        backend=default_backend()
    ).encryptor()
def write_str(filename, directory, content, permissions=None):
    """Write ``content`` to ``filename`` via a temporary file in ``directory``.

    ``content`` is written to a hidden temporary file created inside ``directory``, flushed and
    fsync'ed, given ``permissions`` and only then renamed onto ``filename``, so readers never
    observe a partially written file.

    :param filename: The path of the file to create or replace.
    :param directory: The directory in which the temporary file is created. The rename is only
                      atomic when this is on the same filesystem as ``filename`` (presumably
                      callers pass the parent directory of ``filename`` -- verify at call sites).
    :param content: The byte string to write.
    :param permissions: (optional) The mode to give the file. When falsy, the default derived
                        from the current umask is used.
    """
    if not permissions:
        permissions = utils.get_default_file_permissions()

    descriptor, temp_path = tempfile.mkstemp(prefix='.', suffix='.tmp', dir=directory)
    renamed = False

    try:
        # Wrap the descriptor mkstemp() already opened instead of reopening the path: the file
        # object then owns the descriptor and closes it, so it is not leaked.
        with io.open(descriptor, 'wb') as outfile:
            outfile.write(content)
            outfile.flush()
            os.fsync(outfile.fileno())

        os.chmod(temp_path, permissions)
        os.rename(temp_path, filename)
        renamed = True
    finally:
        if not renamed:
            # Do not leave the hidden temporary file behind when anything above failed.
            try:
                os.remove(temp_path)
            except OSError:
                pass
class Frame(object):
    """Represents a single frame of the message body.

    :param is_final_frame: A boolean indicating whether this frame is the final frame or not.
    :param sequence_number: Indicates which frame number this is in the list of frames.
    :param iv: The IV used to encrypt the content of the frame.
    :param encrypted_content: The encrypted bytes.
    :param authentication_tag: A tag used to verify the integrity of this frame.
    """

    def __init__(self, is_final_frame=None, sequence_number=None, iv=None, encrypted_content=None,
                 authentication_tag=None):
        self.is_final_frame = is_final_frame
        self.sequence_number = sequence_number
        self.iv = iv
        self.encrypted_content = encrypted_content
        self.authentication_tag = authentication_tag

        # Expected field lengths. They are only populated by deserialize(); the public
        # ``*_length`` properties below always measure the actual values instead.
        self._iv_length = None
        self._encrypted_content_length = None
        self._authentication_tag_length = None

    @property
    def iv_length(self):
        """The length of ``iv`` in bytes."""
        return len(self.iv)

    @property
    def encrypted_content_length(self):
        """The length of ``encrypted_content`` in bytes."""
        return len(self.encrypted_content)

    @property
    def authentication_tag_length(self):
        """The length of ``authentication_tag`` in bytes."""
        return len(self.authentication_tag)

    def deserialize(self, byte_array, off, iv_length, encrypted_content_length,
                    authentication_tag_length):
        """Loads information from ``byte_array`` into this object.

        :param byte_array: The byte string holding the serialized frame.
        :param off: The offset in ``byte_array`` at which the frame starts.
        :param iv_length: The number of IV bytes to read.
        :param encrypted_content_length: The number of content bytes to read for a regular
                                         frame. A final frame carries its own length field,
                                         which replaces this value.
        :param authentication_tag_length: The number of authentication tag bytes to read.

        :return: The number of bytes consumed.

        :raises mrcrypt.exceptions.ParseError: If ``byte_array`` ends before the frame does.
        """
        self._iv_length = iv_length
        self._encrypted_content_length = encrypted_content_length
        self._authentication_tag_length = authentication_tag_length

        parsed_bytes = self._parse_sequence_number(byte_array, off)

        self.is_final_frame = self.sequence_number == FINAL_FRAME_SEQUENCE_NUMBER

        if self.is_final_frame:
            # The first four bytes were only the final-frame marker; the real sequence
            # number follows it.
            parsed_bytes += self._parse_sequence_number(byte_array, off + parsed_bytes)

        parsed_bytes += self._parse_iv(byte_array, off + parsed_bytes)

        if self.is_final_frame:
            parsed_bytes += self._parse_encrypted_content_length(byte_array, off + parsed_bytes)

        parsed_bytes += self._parse_encrypted_content(byte_array, off + parsed_bytes)
        parsed_bytes += self._parse_authentication_tag(byte_array, off + parsed_bytes)

        return parsed_bytes

    def serialize(self):
        """Writes this object into a :class:`bytearray`.

        A final frame is prefixed with the final-frame marker and carries an explicit 4-byte
        content length; a regular frame has neither.
        """
        byte_array = bytearray()

        if self.is_final_frame:
            byte_array.extend(utils.num_to_bytes(FINAL_FRAME_SEQUENCE_NUMBER, 4))

        byte_array.extend(utils.num_to_bytes(self.sequence_number, 4))
        byte_array.extend(self.iv)

        if self.is_final_frame:
            byte_array.extend(utils.num_to_bytes(self.encrypted_content_length, 4))

        byte_array.extend(self.encrypted_content)
        byte_array.extend(self.authentication_tag)

        return byte_array

    @staticmethod
    def _require_bytes(byte_array, off, count, description):
        """Raise ``ParseError`` unless ``count`` bytes are available at ``off``."""
        if len(byte_array) - off < count:
            raise exceptions.ParseError('Not enough bytes to parse {}'.format(description))

    def _parse_sequence_number(self, byte_array, off):
        self._require_bytes(byte_array, off, 4, 'sequence number')
        self.sequence_number = utils.bytes_to_int(byte_array[off:off + 4])
        return 4

    def _parse_iv(self, byte_array, off):
        self._require_bytes(byte_array, off, self._iv_length, 'IV')
        self.iv = byte_array[off:off + self._iv_length]
        return self._iv_length

    def _parse_encrypted_content_length(self, byte_array, off):
        # The value is decoded from four unsigned bytes, so it can never be negative.
        self._require_bytes(byte_array, off, 4, 'encrypted content length')
        self._encrypted_content_length = utils.bytes_to_int(byte_array[off:off + 4])
        return 4

    def _parse_encrypted_content(self, byte_array, off):
        # Without this check a truncated frame would yield a silently shortened slice.
        self._require_bytes(byte_array, off, self._encrypted_content_length, 'encrypted content')
        self.encrypted_content = byte_array[off:off + self._encrypted_content_length]
        return self._encrypted_content_length

    def _parse_authentication_tag(self, byte_array, off):
        self._require_bytes(byte_array, off, self._authentication_tag_length,
                            'authentication tag')
        self.authentication_tag = byte_array[off:off + self._authentication_tag_length]
        return self._authentication_tag_length
class Header(object):
    """Represents the header of a message.

    :param version: The version of the message.
    :param type_: The type of the message format (stored as ``self.type``).
    :param algorithm_id: The id of the algorithm used to encrypt the content of the message.
    :param message_id: A random 128-bit value that identifies this message.
    :param encryption_context: A dictionary containing additional authenticated data.
    :param encrypted_data_keys: A list of encrypted data keys. The decrypted data key is what is
                                used to derive the key used for encrypting and decrypting the
                                message content.
    :param content_type: The type of encrypted content: non-framed (1) or framed (2).
    :param reserved_field: An empty field reserved for future use by AWS.
    :param frame_content_length: The length of the encrypted content inside a frame. 0 if
                                 the body is non-framed.
    :param iv: The initialization vector used for the header's authentication tag.
    :param authentication_tag: A tag used to validate the integrity of the header.
    """

    def __init__(self, version=None, type_=None, algorithm_id=None, message_id=None,
                 encryption_context=None, encrypted_data_keys=None, content_type=None,
                 reserved_field=None, frame_content_length=None, iv=None, authentication_tag=None):
        self.version = version
        self.type = type_
        self.algorithm_id = algorithm_id
        self.message_id = message_id
        self.content_type = content_type
        self.reserved_field = reserved_field
        self.frame_content_length = frame_content_length
        self.iv = iv
        self.authentication_tag = authentication_tag

        self.encryption_context = {} if encryption_context is None else encryption_context
        self.encrypted_data_keys = [] if encrypted_data_keys is None else encrypted_data_keys

        # These are only used for validation during deserialization. If this object was not
        # created via deserialize(), then they will remain as `None`.
        self._encryption_context_length = None
        self._encryption_context = None
        self._encrypted_data_key_count = None
        self._iv_length = None

    @property
    def algorithm(self):
        """The algorithm profile looked up from ``algorithm_id``."""
        return algorithms.algorithm_from_id(self.algorithm_id)

    @property
    def encryption_context_length(self):
        """The serialized length of the encryption context, in bytes."""
        # Add two for key/value pair count
        return utils.dict_to_byte_length(self.encryption_context) + 2

    @property
    def encrypted_data_key_count(self):
        """The number of encrypted data keys held by this header."""
        return len(self.encrypted_data_keys)

    @property
    def iv_length(self):
        """The length of ``iv`` in bytes."""
        return len(self.iv)

    def deserialize(self, byte_array, off):
        """Loads the information from ``byte_array`` into this object.

        :param byte_array: The byte string holding the serialized header.
        :param off: The offset in ``byte_array`` at which the header starts.

        :return: The number of bytes consumed.

        :raises mrcrypt.exceptions.ParseError: If ``byte_array`` ends before the header does.
        :raises mrcrypt.exceptions.BadCipherTextError: If a field holds an invalid value.
        """
        # The fields are listed in the order in which they appear on the wire.
        field_parsers = (
            self._parse_version,
            self._parse_type,
            self._parse_algorithm_id,
            self._parse_message_id,
            self._parse_encryption_context_length,
            self._parse_encryption_context,
            self._parse_encrypted_data_key_count,
            self._parse_encrypted_data_keys,
            self._parse_content_type,
            self._parse_reserved_field,
            self._parse_iv_length,
            self._parse_frame_content_length,
            self._parse_iv,
            self._parse_authentication_tag,
        )

        parsed_bytes = 0

        for parse_field in field_parsers:
            parsed_bytes += parse_field(byte_array, off + parsed_bytes)

        return parsed_bytes

    def serialize(self):
        """Writes this object into a byte string."""
        byte_array = bytearray(self.serialize_authenticated_fields())

        byte_array.extend(self.iv)
        byte_array.extend(self.authentication_tag)

        # bytes() is the same type as str on Python 2, but unlike str() it also yields the raw
        # bytes (not a repr) on Python 3.
        return bytes(byte_array)

    def serialize_authenticated_fields(self):
        """Writes all the authenticated fields into a byte string.

        These are all the header fields that precede the IV and the authentication tag, which
        :meth:`serialize` appends afterwards. This function is useful for validating the
        header's integrity.
        """
        byte_array = bytearray()

        byte_array.extend(utils.num_to_bytes(self.version, 1))
        byte_array.extend(utils.num_to_bytes(self.type, 1))
        byte_array.extend(utils.num_to_bytes(self.algorithm_id, 2))
        # NOTE(review): zfill() left-pads a short message id with ASCII '0' characters, not
        # NUL bytes -- confirm that message ids are always 16 bytes long.
        byte_array.extend(self.message_id.zfill(16))

        byte_array.extend(utils.num_to_bytes(self.encryption_context_length, 2))
        byte_array.extend(_EncryptionContext.from_dict(self.encryption_context).serialize())

        byte_array.extend(utils.num_to_bytes(self.encrypted_data_key_count, 2))
        for key in self.encrypted_data_keys:
            byte_array.extend(key.serialize())

        byte_array.extend(utils.num_to_bytes(self.content_type, 1))
        byte_array.extend(utils.num_to_bytes(self.reserved_field, 4))
        byte_array.extend(utils.num_to_bytes(self.iv_length, 1))
        byte_array.extend(utils.num_to_bytes(self.frame_content_length, 4))

        return bytes(byte_array)

    @staticmethod
    def _take(byte_array, off, count, description):
        """Return ``count`` bytes starting at ``off``, or raise ``ParseError`` if the input
        is too short."""
        if len(byte_array) - off < count:
            raise exceptions.ParseError('Not enough bytes to parse {}'.format(description))

        return byte_array[off:off + count]

    def _parse_version(self, byte_array, off):
        self.version = utils.bytes_to_int(self._take(byte_array, off, 1, 'version'))

        if self.version != 1:
            raise exceptions.BadCipherTextError('Invalid version number ({})'.format(self.version))

        return 1

    def _parse_type(self, byte_array, off):
        self.type = utils.bytes_to_int(self._take(byte_array, off, 1, 'message type'))

        if self.type != 0x80:
            raise exceptions.BadCipherTextError('Invalid message type ({})'.format(self.type))

        return 1

    def _parse_algorithm_id(self, byte_array, off):
        self.algorithm_id = utils.bytes_to_int(self._take(byte_array, off, 2, 'algorithm id'))
        return 2

    def _parse_message_id(self, byte_array, off):
        self.message_id = self._take(byte_array, off, 16, 'message id')
        return 16

    def _parse_encryption_context_length(self, byte_array, off):
        # The value is decoded from unsigned bytes, so it can never be negative.
        self._encryption_context_length = utils.bytes_to_int(
            self._take(byte_array, off, 2, 'encryption context length'))
        return 2

    def _parse_encryption_context(self, byte_array, off):
        length = len(byte_array) - off

        if length < self._encryption_context_length:
            raise exceptions.ParseError('Not enough bytes to parse encryption context')

        self._encryption_context = _EncryptionContext()

        parsed_bytes = self._encryption_context.deserialize(byte_array, off)

        self.encryption_context = self._encryption_context.to_dict()

        # The key/value pairs must add up to exactly the length announced in the header.
        if parsed_bytes != self._encryption_context_length:
            raise exceptions.ParseError('Did not properly parse encryption context')

        return self._encryption_context_length

    def _parse_encrypted_data_key_count(self, byte_array, off):
        self._encrypted_data_key_count = utils.bytes_to_int(
            self._take(byte_array, off, 2, 'encrypted data key count'))
        return 2

    def _parse_encrypted_data_keys(self, byte_array, off):
        parsed_bytes = 0

        # Start from an empty list so that deserializing twice does not accumulate keys.
        self.encrypted_data_keys = []

        for _ in range(self._encrypted_data_key_count):
            data_key = EncryptedDataKey()
            parsed_bytes += data_key.deserialize(byte_array, off + parsed_bytes)
            self.encrypted_data_keys.append(data_key)

        return parsed_bytes

    def _parse_content_type(self, byte_array, off):
        self.content_type = utils.bytes_to_int(self._take(byte_array, off, 1, 'content type'))

        if self.content_type not in (1, 2):
            raise exceptions.BadCipherTextError('Invalid content type ({})'
                                                .format(self.content_type))

        return 1

    def _parse_reserved_field(self, byte_array, off):
        self.reserved_field = utils.bytes_to_int(self._take(byte_array, off, 4, 'reserved field'))

        if self.reserved_field != 0:
            raise exceptions.BadCipherTextError('Invalid value for reserved field ({})'
                                                .format(self.reserved_field))

        return 4

    def _parse_iv_length(self, byte_array, off):
        self._iv_length = utils.bytes_to_int(self._take(byte_array, off, 1, 'IV length'))
        return 1

    def _parse_frame_content_length(self, byte_array, off):
        self.frame_content_length = utils.bytes_to_int(
            self._take(byte_array, off, 4, 'frame content length'))
        return 4

    def _parse_iv(self, byte_array, off):
        self.iv = self._take(byte_array, off, self._iv_length, 'IV')
        return self._iv_length

    def _parse_authentication_tag(self, byte_array, off):
        # The tag length is not stored in the header; it comes from the algorithm profile.
        tag_length = self.algorithm.tag_length

        self.authentication_tag = self._take(byte_array, off, tag_length, 'authentication tag')

        return tag_length
+ """ + + def __init__(self, context_pairs=None): + self.context_pairs = [] if context_pairs is None else context_pairs + self._key_value_pair_count = None + + def deserialize(self, byte_array, off): + """Loads the information from ``byte_array`` into this object.""" + parsed_bytes = 0 + parsed_bytes += self._parse_key_value_pair_count(byte_array, off + parsed_bytes) + parsed_bytes += self._parse_key_value_pairs(byte_array, off + parsed_bytes) + return parsed_bytes + + def serialize(self): + """Writes this object into a :class:`bytearray`.""" + byte_array = bytearray() + + byte_array.extend(utils.num_to_bytes(len(self.context_pairs), 2)) + + for pair in self.context_pairs: + byte_array.extend(pair.serialize()) + + return byte_array + + def _parse_key_value_pair_count(self, byte_array, off): + self._key_value_pair_count = utils.bytes_to_int(byte_array[off:off + 2]) + return 2 + + def _parse_key_value_pairs(self, byte_array, off): + parsed_bytes = 0 + + for i in xrange(self._key_value_pair_count): + pair = _EncryptionContextKeyValuePair() + parsed_bytes += pair.deserialize(byte_array, off + parsed_bytes) + self.context_pairs.append(pair) + + return parsed_bytes + + def to_dict(self): + """Converts this object into a dict.""" + return dict((pair.key, pair.value) for pair in self.context_pairs) + + @classmethod + def from_dict(cls, dict_): + """Creates this object from ``dict_``.""" + pairs = [_EncryptionContextKeyValuePair(k, v) for k, v in dict_.iteritems()] + return cls(pairs) + + +class _EncryptionContextKeyValuePair(object): + """Represents a single key-value pair in the encryption context. + + :param key: The key of the key-value pair. + :param value: The value of the key-value pair. 
+ """ + + def __init__(self, key=None, value=None): + self.key = key + self.value = value + + self._key_length = None + self._value_length = None + + @property + def key_length(self): + return len(self.key) + + @property + def value_length(self): + return len(self.value) + + def deserialize(self, byte_array, off): + """Loads information from ``byte_array`` into this object.""" + parsed_bytes = 0 + parsed_bytes += self._parse_key_length(byte_array, off + parsed_bytes) + parsed_bytes += self._parse_key(byte_array, off + parsed_bytes) + parsed_bytes += self._parse_value_length(byte_array, off + parsed_bytes) + parsed_bytes += self._parse_value(byte_array, off + parsed_bytes) + return parsed_bytes + + def serialize(self): + """Writes this object into a :class:`bytearray`.""" + byte_array = bytearray() + byte_array.extend(utils.num_to_bytes(self.key_length, 2)) + byte_array.extend(self.key) + byte_array.extend(utils.num_to_bytes(self.value_length, 2)) + byte_array.extend(self.value) + return byte_array + + def _parse_key_length(self, byte_array, off): + self._key_length = utils.bytes_to_int(byte_array[off:off + 2]) + return 2 + + def _parse_key(self, byte_array, off): + self.key = byte_array[off:off + self._key_length] + return self._key_length + + def _parse_value_length(self, byte_array, off): + self._value_length = utils.bytes_to_int(byte_array[off:off + 2]) + return 2 + + def _parse_value(self, byte_array, off): + self.value = byte_array[off:off + self._value_length] + return self._value_length + + +class EncryptedDataKey(object): + """Represents an encrypted data key. + + :param key_provider_id: The provider of the key. This tool only supports KMS. + :param key_provider_info: Information about the key provider. + :param encrypted_data_key: The encrypted data key. 
+ """ + def __init__(self, key_provider_id=None, key_provider_info=None, encrypted_data_key=None): + self.key_provider_id = key_provider_id + self.key_provider_info = key_provider_info + self.encrypted_data_key = encrypted_data_key + + self._key_provider_id_length = None + self._key_provider_info_length = None + self._encrypted_data_key_length = None + + @property + def key_provider_id_length(self): + return len(self.key_provider_id) + + @property + def key_provider_info_length(self): + return len(self.key_provider_info) + + @property + def encrypted_data_key_length(self): + return len(self.encrypted_data_key) + + def deserialize(self, byte_array, off): + """Loads the information in ``byte_array`` into this object.""" + parsed_bytes = 0 + parsed_bytes += self._parse_key_provider_id_length(byte_array, off + parsed_bytes) + parsed_bytes += self._parse_key_provider_id(byte_array, off + parsed_bytes) + parsed_bytes += self._parse_key_provider_info_length(byte_array, off + parsed_bytes) + parsed_bytes += self._parse_key_provider_info(byte_array, off + parsed_bytes) + parsed_bytes += self._parse_encrypted_data_key_length(byte_array, off + parsed_bytes) + parsed_bytes += self._parse_encrypted_data_key(byte_array, off + parsed_bytes) + return parsed_bytes + + def serialize(self): + """Writes this object into a :class:`bytearray`.""" + byte_array = bytearray() + byte_array.extend(utils.num_to_bytes(self.key_provider_id_length, 2)) + byte_array.extend(self.key_provider_id) + byte_array.extend(utils.num_to_bytes(self.key_provider_info_length, 2)) + byte_array.extend(self.key_provider_info) + byte_array.extend(utils.num_to_bytes(self.encrypted_data_key_length, 2)) + byte_array.extend(self.encrypted_data_key) + return byte_array + + def _parse_key_provider_id_length(self, byte_array, off): + self._key_provider_id_length = utils.bytes_to_int(byte_array[off:off + 2]) + return 2 + + def _parse_key_provider_id(self, byte_array, off): + length = len(byte_array) - off + + if length < 
class Message(object):
    """Represents the entire message.

    :param header: A message header.
    :param body: The message body.
    :param footer: The message footer (one may not exist).
    """

    def __init__(self, header=None, body=None, footer=None):
        self.header = header
        self.body = body
        self.footer = footer

    def deserialize(self, byte_array, off):
        """Loads information from ``byte_array`` into this object.

        :param byte_array: The byte string holding the serialized message.
        :param off: The offset in ``byte_array`` at which the message starts.

        :return: The number of bytes consumed.

        :raises NotImplementedError: If the header announces non-framed content.
        :raises mrcrypt.exceptions.ParseError: If bytes are left over after parsing.
        """
        parsed_bytes = 0

        self.header = Header()
        parsed_bytes += self.header.deserialize(byte_array, off + parsed_bytes)

        if self.header.content_type != 2:
            raise NotImplementedError('Non-framed content not supported yet')

        self.body = FrameBody(header=self.header)
        parsed_bytes += self.body.deserialize(byte_array, off + parsed_bytes)

        # Any bytes after the body are taken to be the footer. The remaining length has to be
        # measured from the absolute position (off + parsed_bytes), not from parsed_bytes alone.
        self.footer = None
        if len(byte_array) - (off + parsed_bytes) > 0:
            self.footer = Footer()
            parsed_bytes += self.footer.deserialize(byte_array, off + parsed_bytes)

        if off + parsed_bytes != len(byte_array):
            raise exceptions.ParseError('Did not parse all the bytes')

        return parsed_bytes

    def serialize(self):
        """Writes this object to a byte string."""
        byte_array = bytearray(self.serialize_authenticated_fields())

        if self.footer is not None:
            byte_array.extend(self.footer.serialize())

        # bytes() is the same type as str on Python 2, but unlike str() it also yields the raw
        # bytes (not a repr) on Python 3.
        return bytes(byte_array)

    def serialize_authenticated_fields(self):
        """Writes the header and the body to a byte string. Useful for validating the message's
        integrity."""
        byte_array = bytearray()

        byte_array.extend(self.header.serialize())
        byte_array.extend(self.body.serialize())

        return bytes(byte_array)
def region_from_arn(arn):
    """
    Extracts the region from an ARN.

    :param arn: An ARN.

    :return: The region in ``arn``, i.e. its fourth ``:``-separated field.
    """
    return arn.split(':')[3]


def bytes_to_int(byte_array):
    """
    Converts a byte string into a number, reading it as a big-endian unsigned integer.

    :param byte_array: The byte string to convert. An ``int`` (what indexing ``bytes`` yields
                       on Python 3) is returned unchanged.

    :return: The number represented by ``byte_array``.

    :raises ValueError: If ``byte_array`` is empty.
    """
    # Imported locally so this helper does not depend on the module's import block.
    import binascii

    if isinstance(byte_array, int):
        return byte_array

    # hexlify() works on Python 2 and 3 and accepts bytearray input; the previous
    # ``.encode('hex')`` exists on neither.
    return int(binascii.hexlify(bytes(byte_array)), 16)


def num_to_bytes(number, length):
    """
    Converts ``number`` to a big-endian byte string of exactly ``length`` bytes.

    :param number: A non-negative number.
    :param length: The number of bytes the resulting byte string should be.

    :return: The number as a byte string.

    :raises ValueError: If ``number`` is negative or does not fit into ``length`` bytes.
    """
    if number < 0:
        raise ValueError('Cannot convert a negative number ({}) to bytes'.format(number))

    hex_str = format(number, 'x').zfill(length * 2)

    # zfill() only pads; without this check an oversized number would silently produce more
    # than ``length`` bytes and shift every field that follows it on the wire.
    if len(hex_str) > length * 2:
        raise ValueError('{} does not fit into {} byte(s)'.format(number, length))

    return bytes(bytearray.fromhex(hex_str))


def dict_to_byte_length(dict_):
    """Get the byte length of a dictionary. This returns the sum of the length of the key and the
    value and 4 bytes to store the lengths of the key and the value, for every pair in
    ``dict_``."""
    return sum(len(key) + len(value) + 4 for key, value in dict_.items())


def split(iterable, size):
    """Splits ``iterable`` into parts of size ``size``; the last part may be shorter.

    :param iterable: A sliceable sequence.
    :param size: The maximum length of each part.

    :return: A tuple of the parts, empty when ``iterable`` is empty.
    """
    return tuple(iterable[x:x + size] for x in range(0, len(iterable), size))


def random_string(length):
    """Generates a random string of ``length`` printable characters.

    NOTE(review): this uses the ``random`` module, which is not meant for security purposes --
    confirm it is never used to produce keys, IVs or message ids.
    """
    return ''.join(random.choice(string.printable) for __ in range(length))
def get_file_permissions(filename):
    """Gets the permissions of ``filename``.

    :return: The user/group/other permission bits of the file's mode.
    """
    # ``0o777`` is valid on Python 2.6+ and Python 3; the bare ``0777`` form is a syntax
    # error on Python 3.
    return os.stat(filename).st_mode & 0o777


def get_parent_dir_path(filename):
    """Gets the absolute path of the parent directory that ``filename`` belongs to."""
    return os.path.dirname(os.path.abspath(filename))


def get_umask():
    """Returns the current umask.

    ``os.umask()`` can only report the old mask while installing a new one, so the mask is
    briefly set to 0 and then restored.
    """
    current_umask = os.umask(0)
    os.umask(current_umask)

    return current_umask


def get_default_file_permissions():
    """Returns default file permissions from the current umask: ``0o666`` with the umask bits
    cleared."""
    umask = get_umask()
    return 0o666 & ~umask
long_description=readme, + + url='https://github.com/aol/mrcrypt', + + license='Apache 2.0', + + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'Environment :: Console', + 'Intended Audience :: Developers', + 'Intended Audience :: System Administrators', + 'Natural Language :: English', + 'Operating System :: MacOS :: MacOS X', + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 2.7', + 'Topic :: Security', + 'Topic :: Security :: Cryptography', + ], + + packages=find_packages(), + + entry_points={ + 'console_scripts': [ + 'mrcrypt=mrcrypt.main:main' + ] + }, + + install_requires=[ + 'boto3>=0.0.17', + 'cryptography>=1.1', + ], +) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py new file mode 100644 index 0000000..48db0cb --- /dev/null +++ b/tests/test_algorithms.py @@ -0,0 +1,32 @@ +import pytest + +from mrcrypt import algorithms + + +@pytest.mark.parametrize('algorithm_id, expected_algorithm', [ + (0x0014, algorithms.alg_aes_128_gcm_iv12_tag16_no_kdf), + (0x0046, algorithms.alg_aes_192_gcm_iv12_tag16_no_kdf), + (0x0078, algorithms.alg_aes_256_gcm_iv12_tag16_no_kdf), + (0x0114, algorithms.alg_aes_128_gcm_iv12_tag16_hkdf_sha256), + (0x0146, algorithms.alg_aes_192_gcm_iv12_tag16_hkdf_sha256), + (0x0178, algorithms.alg_aes_256_gcm_iv12_tag16_hkdf_sha256), + (0x0214, algorithms.alg_aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256), + (0x0346, algorithms.alg_aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384), + (0x0378, algorithms.alg_aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384), +]) +def test_algorithm_from_id(algorithm_id, expected_algorithm): + algorithm = algorithms.algorithm_from_id(algorithm_id) + + assert algorithm == expected_algorithm + + +@pytest.mark.parametrize('algorithm_id', [0x0000]) +def test_algorithm_from_id__invalid_id(algorithm_id): + with pytest.raises(ValueError): + algorithms.algorithm_from_id(algorithm_id) + + +def test_default_algorithm(): + result = algorithms.default_algorithm() + + 
assert isinstance(result, algorithms.AlgorithmProfile) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..f30cea5 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,191 @@ +import logging +import os +import stat + +import pytest +import moto +import boto3 + +from mrcrypt.cli import commands, parser + +SECRET = 'my secret' + + +def test_encrypt_arg__all(): + arguments = ('--profile default -vv --outfile outfile.txt encrypt ' + '--encryption_context {"1":"1"} --regions us-east-1 -- ' + 'alias/test-key secrets.txt') + args = parser.parse_args(arguments.split()) + + assert args.profile == 'default' + assert args.verbose == 2 + assert args.encryption_context == {"1": "1"} + assert args.regions == ['us-east-1'] + assert args.filename == 'secrets.txt' + assert args.outfile == 'outfile.txt' + assert args.key_id == 'alias/test-key' + assert args.command == 'encrypt' + + +def test_encrypt_arg__multiple_regions(): + arguments = 'encrypt --regions us-east-1 us-west-2 eu-west-1 -- alias/test-key secrets.txt' + args = parser.parse_args(arguments.split()) + + assert args.command == 'encrypt' + assert args.key_id == 'alias/test-key' + assert args.filename == 'secrets.txt' + assert args.regions == ['us-east-1', 'us-west-2', 'eu-west-1'] + assert args.verbose == None + assert args.encryption_context is None + assert args.profile is None + assert args.outfile is None + + +def test_encrypt_arg__minimum_args(): + arguments = 'encrypt alias/test-key secrets.txt' + args = parser.parse_args(arguments.split()) + + assert args.command == 'encrypt' + assert args.key_id == 'alias/test-key' + assert args.filename == 'secrets.txt' + assert args.verbose == None + assert args.profile is None + assert args.encryption_context is None + assert args.regions is None + assert args.outfile is None + + +def test_decrypt_arg__all(): + arguments = '--profile default --outfile outfile.txt -vv decrypt secrets.txt' + args = parser.parse_args(arguments.split()) + + assert 
args.command == 'decrypt' + assert args.filename == 'secrets.txt' + assert args.outfile == 'outfile.txt' + assert args.profile == 'default' + assert args.verbose == 2 + + +def test_decrypt_arg__minimum(): + arguments = 'decrypt secrets.txt' + args = parser.parse_args(arguments.split()) + + assert args.command == 'decrypt' + assert args.filename == 'secrets.txt' + assert args.verbose == None + assert args.profile is None + assert args.outfile is None + + +@pytest.mark.parametrize('infile, outfile, expected', ( + ('secrets.txt', None, 'secrets.txt.decrypted'), + ('secrets.txt.encrypted', None, 'secrets.txt'), + ('secrets.txt.encrypted', 'decrypted.txt', 'decrypted.txt'), + ('secrets.properties.encrypted', None, 'secrets.properties'), + ('secrets.encrypted', None, 'secrets'), +)) +def test_generate_decrypt_filename(infile, outfile, expected): + decrypt_command = commands.DecryptCommand(infile, outfile=outfile) + assert decrypt_command._generate_outfile(infile) == expected + + +@pytest.mark.parametrize('infile, outfile, expected', ( + ('secrets.txt', None, 'secrets.txt.encrypted'), + ('secrets.txt', 'encrypted.txt', 'encrypted.txt'), +)) +def test_generate_encrypt_filename(infile, outfile, expected): + encrypt_command = commands.EncryptCommand(infile, None, outfile=outfile) + assert encrypt_command._generate_outfile(infile) == expected + + +@moto.mock_kms +def test_cli__encrypt_decrypt_flow(setup_files_tuple, kms_master_key_arn): + secrets_file, encrypted_file, decrypted_file = setup_files_tuple + + with open(secrets_file, 'w') as f: + f.write(SECRET) + + encrypt_command = commands.EncryptCommand(secrets_file, kms_master_key_arn, + outfile=encrypted_file) + encrypt_command.encrypt() + + decrypt_command = commands.DecryptCommand(encrypted_file, outfile=decrypted_file) + decrypt_command.decrypt() + + with open(decrypted_file, 'r') as f: + assert f.read() == SECRET + + assert stat.S_IRUSR == os.stat(decrypted_file).st_mode & 0777 + + +@moto.mock_kms +def 
test_cli__encrypt_decrypt_directory_flow(secrets_dir, kms_master_key_arn): + encrypt_command = commands.EncryptCommand(secrets_dir, kms_master_key_arn) + encrypt_command.encrypt() + + assert os.path.isfile(os.path.join(secrets_dir, 'secrets-1.txt.encrypted')) + assert os.path.isfile(os.path.join(secrets_dir, 'secrets-2.txt.encrypted')) + + os.remove(os.path.join(secrets_dir, 'secrets-1.txt')) + os.remove(os.path.join(secrets_dir, 'secrets-2.txt')) + + decrypt_command = commands.DecryptCommand(secrets_dir) + decrypt_command.decrypt() + + with open(os.path.join(secrets_dir, 'secrets-1.txt')) as f: + assert f.read() == SECRET + + with open(os.path.join(secrets_dir, 'secrets-2.txt')) as f: + assert f.read() == SECRET + + +@pytest.mark.parametrize('verbosity_level, expected_level', ( + (None, logging.WARN), + (1, logging.INFO), + (2, logging.DEBUG), + (10, logging.DEBUG), +)) +def test_set_logging_level(verbosity_level, expected_level): + assert expected_level == parser._get_logging_level(verbosity_level) + + +@pytest.fixture +def setup_files_tuple(tmpdir): + secrets_file = tmpdir.join('secrets.txt') + secrets_file.ensure(file=True) + secrets_file_path = str(secrets_file) + + encrypted_file = tmpdir.join(secrets_file_path + '.encrypted') + encrypted_file.ensure(file=True) + encrypted_file_path = str(encrypted_file) + + decrypted_file = tmpdir.join('decrypted.txt') + decrypted_file.ensure(file=True) + decrypted_file_path = str(decrypted_file) + + return secrets_file_path, encrypted_file_path, decrypted_file_path + + +@pytest.fixture +def secrets_dir(tmpdir): + secrets_file_one = tmpdir.join('secrets-1.txt') + secrets_file_one.ensure(file=True) + + secrets_file_two = tmpdir.join('secrets-2.txt') + secrets_file_two.ensure(file=True) + + with open(str(secrets_file_one), 'w') as f: + f.write(SECRET) + + with open(str(secrets_file_two), 'w') as f: + f.write(SECRET) + + return str(tmpdir) + + +@pytest.fixture +@moto.mock_kms +def kms_master_key_arn(): + client = 
boto3.client('kms') + response = client.create_key() + return response['KeyMetadata']['Arn'] diff --git a/tests/test_crypto.py b/tests/test_crypto.py new file mode 100644 index 0000000..692271b --- /dev/null +++ b/tests/test_crypto.py @@ -0,0 +1,111 @@ +import pytest +import boto3 +import moto +from cryptography.exceptions import InvalidTag, InvalidSignature + +from mrcrypt import message, utils +from mrcrypt.crypto.encryption import encrypt_string +from mrcrypt.crypto.decryption import decrypt_message + + +@pytest.mark.parametrize('regions', [ + [], + ['us-east-1'], + ['us-east-1', 'us-west-2'], + ['us-east-1', 'us-west-2', 'eu-west-1'] +]) +@moto.mock_kms +def test_encrypt_and_decrypt(kms_key_id, regions): + plaintext = 'test string' + msg = encrypt_string(plaintext, kms_key_id, regions) + + result = decrypt_message(msg) + + assert result == plaintext + + +@pytest.mark.parametrize('regions', [ + [], + ['us-east-1'], + ['us-east-1', 'us-west-2'], + ['us-east-1', 'us-west-2', 'eu-west-1'] +]) +@moto.mock_kms +def test_encrypt_decrypt__large_content(kms_key_id, regions): + plaintext = utils.random_string(5000) + msg = encrypt_string(plaintext, kms_key_id, regions) + + assert isinstance(msg.body, message.FrameBody) + assert len(msg.body.frames) == 2 + + result = decrypt_message(msg) + + assert result == plaintext + + +@pytest.mark.parametrize('regions', [ + [], + ['us-east-1'], + ['us-east-1', 'us-west-2'], + ['us-east-1', 'us-west-2', 'eu-west-1'] +]) +@moto.mock_kms +def test_encrypt_decrypt__encryption_context(kms_key_id, regions): + plaintext = 'test string' + encryption_context = {'test_key': 'test_value'} + msg = encrypt_string(plaintext, kms_key_id, regions, encryption_context=encryption_context) + + byte_array = msg.serialize() + + assert 'test_key' in byte_array + assert 'test_value' in byte_array + + result = decrypt_message(msg) + + assert result == plaintext + + +@moto.mock_kms +def test_header_integrity_check(kms_key_id): + plaintext = 'test string' + 
msg = encrypt_string(plaintext, kms_key_id, ['us-east-1']) + + byte_array = bytearray(msg.serialize()) + + byte_array[4] = ~byte_array[4] & 0xff # ensure the number isn't the same + + msg = message.Message() + msg.deserialize(str(byte_array), 0) + + with pytest.raises(InvalidTag): + decrypt_message(msg) + + +@moto.mock_kms +def test_message_integrity_check(kms_key_id): + plaintext = 'test string' + msg = encrypt_string(plaintext, kms_key_id, ['us-east-1']) + + byte_array = bytearray(msg.serialize()) + authenticated_fields_byte_array = bytearray(msg.serialize_authenticated_fields()) + + authenticated_fields_byte_array[-1] = ~authenticated_fields_byte_array[-1] & 0xff + + byte_array[:len(authenticated_fields_byte_array)] = authenticated_fields_byte_array + + msg = message.Message() + msg.deserialize(str(byte_array), 0) + + with pytest.raises(InvalidSignature): + decrypt_message(msg) + + +@pytest.fixture +@moto.mock_kms +def kms_key_id(): + client = boto3.client('kms') + + response = client.create_key() + + return response['KeyMetadata']['Arn'] + diff --git a/tests/test_io.py b/tests/test_io.py new file mode 100644 index 0000000..c05e7ac --- /dev/null +++ b/tests/test_io.py @@ -0,0 +1,22 @@ +import os +import stat + +import pytest + +import mrcrypt.io + +FILE_CONTENTS = 'this is a string' + + +@pytest.mark.parametrize('permissions_octal', ( + 0100600, + 0100660, + 0100666 +)) +def test_write_str__file_permissions(tmpdir, permissions_octal): + filename = tmpdir.join('test.txt').strpath + mrcrypt.io.write_str(filename, '', FILE_CONTENTS, permissions_octal) + + assert os.stat(filename)[stat.ST_MODE] == permissions_octal + with open(filename) as f: + assert f.read() == FILE_CONTENTS diff --git a/tests/test_message.py b/tests/test_message.py new file mode 100644 index 0000000..55e13e8 --- /dev/null +++ b/tests/test_message.py @@ -0,0 +1,348 @@ +import struct +from collections import OrderedDict + +import pytest + +from mrcrypt import message + +ONE_AS_ONE_BYTE = 
struct.pack('>B', 0x01) +ONE_AS_TWO_BYTES = struct.pack('>H', 0x0001) +ONE_AS_FOUR_BYTES = struct.pack('>L', 0x0000001) +ONE_AS_EIGHT_BYTES = struct.pack('>Q', 0x0000000000000001) +SIXTEEN_BYTES = struct.pack('>Q', 0x0000000000000000) + struct.pack('>Q', 0x0000000000000001) + + +def test_header__deserialize(header_bytes): + header = message.Header() + parsed_bytes = header.deserialize(header_bytes, 0) + + assert parsed_bytes == len(header_bytes) + + assert header.version == 1 + assert header.type == 0x80 + assert header.algorithm_id == 0x0014 + + assert header.message_id == SIXTEEN_BYTES + + assert header.encryption_context_length == 8 + assert len(header.encryption_context) == 1 + + context_pair = header.encryption_context.items()[0] + assert context_pair == (ONE_AS_ONE_BYTE, ONE_AS_ONE_BYTE) + + assert len(header.encrypted_data_keys) == 1 + encrypted_data_key = header.encrypted_data_keys[0] + + assert encrypted_data_key.key_provider_id_length == 1 + assert encrypted_data_key.key_provider_id == ONE_AS_ONE_BYTE + assert encrypted_data_key.key_provider_info_length == 1 + assert encrypted_data_key.key_provider_info == ONE_AS_ONE_BYTE + assert encrypted_data_key.encrypted_data_key_length == 1 + assert encrypted_data_key.encrypted_data_key == ONE_AS_ONE_BYTE + + assert header.content_type == 1 + assert header.reserved_field == 0x00000000 + assert header.iv_length == 1 + assert header.frame_content_length == 1 + assert header.iv == ONE_AS_ONE_BYTE + + assert header.authentication_tag == SIXTEEN_BYTES + + +def test_header__deserialize_not_enough_bytes(): + with pytest.raises(IndexError): + header = message.Header() + header.deserialize(struct.pack('>B', 1), 0) + + +def test_header__serialize(header_bytes): + encrypted_data_key = message.header.EncryptedDataKey( + key_provider_id=ONE_AS_ONE_BYTE, + key_provider_info=ONE_AS_ONE_BYTE, + encrypted_data_key=ONE_AS_ONE_BYTE, + ) + + header = message.Header( + version=1, + type_=0x80, + algorithm_id=0x0014, + 
message_id=SIXTEEN_BYTES, + encryption_context={ONE_AS_ONE_BYTE: ONE_AS_ONE_BYTE}, + encrypted_data_keys=[encrypted_data_key], + content_type=1, + reserved_field=0, + frame_content_length=1, + iv=ONE_AS_ONE_BYTE, + authentication_tag=SIXTEEN_BYTES, + ) + + serialized_bytes = header.serialize() + + assert serialized_bytes == header_bytes + + +def test_header__serialize_multiple_encryption_contexts(header_bytes_multiple_encryption_contexts): + encrypted_data_key = message.header.EncryptedDataKey( + key_provider_id=ONE_AS_ONE_BYTE, + key_provider_info=ONE_AS_ONE_BYTE, + encrypted_data_key=ONE_AS_ONE_BYTE, + ) + + # dict's are unordered, which can mess up the tests. + encryption_context = OrderedDict() + encryption_context[ONE_AS_ONE_BYTE] = ONE_AS_ONE_BYTE + encryption_context[ONE_AS_TWO_BYTES] = ONE_AS_TWO_BYTES + + header = message.Header( + version=1, + type_=0x80, + algorithm_id=0x0014, + message_id=SIXTEEN_BYTES, + encryption_context=encryption_context, + encrypted_data_keys=[encrypted_data_key], + content_type=1, + reserved_field=0, + frame_content_length=1, + iv=ONE_AS_ONE_BYTE, + authentication_tag=SIXTEEN_BYTES, + ) + + serialized_bytes = header.serialize() + + assert serialized_bytes == header_bytes_multiple_encryption_contexts + + +def test_frame_body__deserialize(message_header, frame_body_bytes): + body = message.FrameBody(header=message_header) + parsed_bytes = body.deserialize(frame_body_bytes, 0) + + assert parsed_bytes == len(frame_body_bytes) + + assert len(body.frames) == 1 + + frame = body.frames[0] + + assert frame.is_final_frame == True + assert frame.sequence_number == 1 + + assert frame.iv_length == 12 + assert frame.iv == ONE_AS_ONE_BYTE * 12 + + assert frame.encrypted_content_length == 1 + assert frame.encrypted_content == ONE_AS_ONE_BYTE + + assert frame.authentication_tag_length == 16 + assert frame.authentication_tag == SIXTEEN_BYTES + + +def test_single_frame_body__serialize(message_header): + iv = ONE_AS_ONE_BYTE * 12 + encrypted_content 
= ONE_AS_ONE_BYTE + auth_tag = SIXTEEN_BYTES + + frame = message.body.Frame(True, 1, iv, encrypted_content, auth_tag) + frame_body = message.FrameBody(header=message_header, frames=[frame]) + + result = frame_body.serialize() + + expected_bytes = (struct.pack('>L', 0xFFFFFFFF) + ONE_AS_FOUR_BYTES + iv + ONE_AS_FOUR_BYTES + + encrypted_content + auth_tag) + + assert result == expected_bytes + + +@pytest.mark.parametrize("num_frames", [2, 3, 4, 8, 16]) +def test_multi_frame_body__serialize(num_frames, message_header): + frames = [] + + expected_bytes = "" + + iv = ONE_AS_ONE_BYTE * 12 + auth_tag = SIXTEEN_BYTES + + for i in xrange(1, num_frames): + encrypted_content = ONE_AS_ONE_BYTE * 4096 + + frames.append(message.body.Frame(False, i, iv, encrypted_content, auth_tag)) + + expected_bytes += struct.pack('>L', i) + iv + encrypted_content + auth_tag + + encrypted_content = ONE_AS_ONE_BYTE + + frames.append(message.body.Frame(True, num_frames, iv, encrypted_content, auth_tag)) + + expected_bytes += (struct.pack('>L', 0xFFFFFFFF) + struct.pack('>L', num_frames) + iv + + ONE_AS_FOUR_BYTES + encrypted_content + auth_tag) + + body = message.FrameBody(header=message_header, frames=frames) + + result = body.serialize() + + assert result == expected_bytes + + +def test_footer__deserialize(): + byte_array = ONE_AS_TWO_BYTES + ONE_AS_ONE_BYTE + + footer = message.Footer() + + footer.deserialize(byte_array, 0) + + assert footer.signature == ONE_AS_ONE_BYTE + assert footer.signature_length == 1 + + +def test_footer__serialize(): + footer = message.Footer(signature=ONE_AS_ONE_BYTE) + + result = footer.serialize() + + expected = ONE_AS_TWO_BYTES + ONE_AS_ONE_BYTE + + assert result == expected + + +@pytest.fixture +def header_bytes(): + """ + :func:`mrcrypt.utils.num_to_bytes` wasn't used because this function shouldn't be dependent on + the correctness of it. 
+ """ + version = struct.pack('>B', 0x01) + type_ = struct.pack('>B', 0x80) + algorithm_id = struct.pack('>H', 0x0014) + + message_id = SIXTEEN_BYTES + + encryption_context_pair_count = ONE_AS_TWO_BYTES + encryption_key_length = ONE_AS_TWO_BYTES + encryption_context_key = ONE_AS_ONE_BYTE + encryption_context_value_length = ONE_AS_TWO_BYTES + encryption_context_value = ONE_AS_ONE_BYTE + + encryption_context = (encryption_context_pair_count + encryption_key_length + + encryption_context_key + encryption_context_value_length + + encryption_context_value) + encryption_context_length = struct.pack('>H', len(encryption_context)) + + encrypted_data_key_count = ONE_AS_TWO_BYTES + + key_provider_id_length = ONE_AS_TWO_BYTES + key_provider_id = ONE_AS_ONE_BYTE + key_provider_info_length = ONE_AS_TWO_BYTES + key_provider_info = ONE_AS_ONE_BYTE + encrypted_key_length = ONE_AS_TWO_BYTES + encrypted_key = ONE_AS_ONE_BYTE + + encrypted_data_key = (key_provider_id_length + key_provider_id + key_provider_info_length + + key_provider_info + encrypted_key_length + encrypted_key) + + content_type = ONE_AS_ONE_BYTE + reserved = struct.pack('>L', 0x00000000) + iv_length = ONE_AS_ONE_BYTE + frame_length = ONE_AS_FOUR_BYTES + iv = ONE_AS_ONE_BYTE + + auth_tag = SIXTEEN_BYTES + + header_bytes_ = (version + type_ + algorithm_id + message_id + encryption_context_length + + encryption_context + encrypted_data_key_count + encrypted_data_key + + content_type + reserved + iv_length + frame_length + iv + auth_tag) + + return header_bytes_ + + +@pytest.fixture +def header_bytes_multiple_encryption_contexts(): + version = struct.pack('>B', 0x01) + type_ = struct.pack('>B', 0x80) + algorithm_id = struct.pack('>H', 0x0014) + + message_id = SIXTEEN_BYTES + + encryption_context_pair_count = struct.pack('>H', 0x0002) + encryption_context_key_length_one = ONE_AS_TWO_BYTES + encryption_context_key_one = ONE_AS_ONE_BYTE + encryption_context_value_length_one = ONE_AS_TWO_BYTES + encryption_context_value_one 
= ONE_AS_ONE_BYTE + + encryption_context_key_length_two = struct.pack('>H', 0x0002) + encryption_context_key_two = ONE_AS_TWO_BYTES + encryption_context_value_length_two = struct.pack('>H', 0x0002) + encryption_context_value_two = ONE_AS_TWO_BYTES + + encryption_context = (encryption_context_pair_count + encryption_context_key_length_one + + encryption_context_key_one + encryption_context_value_length_one + + encryption_context_value_one + encryption_context_key_length_two + + encryption_context_key_two + encryption_context_value_length_two + + encryption_context_value_two) + encryption_context_length = struct.pack('>H', len(encryption_context)) + + encrypted_data_key_count = ONE_AS_TWO_BYTES + + key_provider_id_length = ONE_AS_TWO_BYTES + key_provider_id = ONE_AS_ONE_BYTE + key_provider_info_length = ONE_AS_TWO_BYTES + key_provider_info = ONE_AS_ONE_BYTE + encrypted_key_length = ONE_AS_TWO_BYTES + encrypted_key = ONE_AS_ONE_BYTE + + encrypted_data_key = (key_provider_id_length + key_provider_id + key_provider_info_length + + key_provider_info + encrypted_key_length + encrypted_key) + + content_type = ONE_AS_ONE_BYTE + reserved = struct.pack('>L', 0x00000000) + iv_length = ONE_AS_ONE_BYTE + frame_length = ONE_AS_FOUR_BYTES + iv = ONE_AS_ONE_BYTE + + auth_tag = SIXTEEN_BYTES + + header_bytes_ = (version + type_ + algorithm_id + message_id + encryption_context_length + + encryption_context + encrypted_data_key_count + encrypted_data_key + + content_type + reserved + iv_length + frame_length + iv + auth_tag) + + return header_bytes_ + + +@pytest.fixture(params=[0x0014, 0x0046, 0x0078, 0x0114, 0x0146, 0x0178, 0x0214, 0x0346, 0x0378]) +def message_header(request): + algorithm_id = request.param + + encrypted_data_key = message.header.EncryptedDataKey( + key_provider_id=ONE_AS_ONE_BYTE, + key_provider_info=ONE_AS_ONE_BYTE, + encrypted_data_key=ONE_AS_ONE_BYTE, + ) + + header = message.Header( + version=1, + type_=0x80, + algorithm_id=algorithm_id, + 
message_id=SIXTEEN_BYTES, + encryption_context={ONE_AS_ONE_BYTE: ONE_AS_ONE_BYTE}, + encrypted_data_keys=[encrypted_data_key], + content_type=1, + reserved_field=0, + frame_content_length=1, + iv=ONE_AS_ONE_BYTE * 12, + authentication_tag=SIXTEEN_BYTES, + ) + + return header + + +@pytest.fixture +def frame_body_bytes(): + sequence_number_end = struct.pack('>L', 0xFFFFFFFF) + sequence_number = ONE_AS_FOUR_BYTES + + iv = ONE_AS_ONE_BYTE * 12 + + encrypted_content_length = ONE_AS_FOUR_BYTES + encrypted_content = ONE_AS_ONE_BYTE + + authentication_tag = SIXTEEN_BYTES + + return (sequence_number_end + sequence_number + iv + encrypted_content_length + + encrypted_content + authentication_tag) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..e3cd35a --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,91 @@ +import struct +import os + +import pytest + +from mrcrypt import utils + + +def test_get_arns(): + arns = utils.get_arns(['us-east-1', 'us-west-2', 'eu-west-1'], 1234567890, 'test-alias') + + expected_arns = ['arn:aws:kms:us-east-1:1234567890:alias/test-alias', + 'arn:aws:kms:us-west-2:1234567890:alias/test-alias', + 'arn:aws:kms:eu-west-1:1234567890:alias/test-alias'] + + assert expected_arns == arns + + +@pytest.mark.parametrize('arn, expected_region', ( + ('arn:aws:kms:us-east-1:1234567890:alias/test-alias', 'us-east-1'), + ('arn:aws:ec2:us-west-2:1234567890:key/1234567890', 'us-west-2'), +)) +def test_region_from_arn(arn, expected_region): + region = utils.region_from_arn(arn) + assert region == expected_region + + +@pytest.mark.parametrize('number, length, expected', ( + (1, 1, struct.pack('>B', 1)), + (1, 8, struct.pack('>Q', 1)), + (2 ** 16 - 1, 2, struct.pack('>H', 2 ** 16 - 1)) +)) +def test_num_to_bytes(number, length, expected): + result = utils.num_to_bytes(number, length) + assert result == expected + + +@pytest.mark.parametrize('byte_string, expected_int', ( + ('0', 48), + ('1', 49), + ('10', 12592), + ('00', 
12336) +)) +def test_bytes_to_int(byte_string, expected_int): + result = utils.bytes_to_int(byte_string) + + assert result == expected_int + + +@pytest.mark.parametrize('dict_, expected_length', ( + # Add 4 because of the two 2-byte fields for key/value length + ({'1': '2'}, 2 + 4), + ({'12': '34'}, 4 + 4) +)) +def test_dict_to_byte_length(dict_, expected_length): + result = utils.dict_to_byte_length(dict_) + + assert result == expected_length + + +@pytest.mark.parametrize('str_, size, expected', ( + ('11', 1, ('1', '1')), + ('splitme', 2, ('sp', 'li', 'tm', 'e')) +)) +def test_split(str_, size, expected): + result = utils.split(str_, size) + + assert result == expected + + +@pytest.mark.parametrize('size', ( + 0, + 1, + 100, + 5000 +)) +def test_random_string(size): + assert len(utils.random_string(size)) == size + + +@pytest.mark.parametrize('permissions_octal', ( + 0600, + 0660, + 0666 +)) +def test_get_file_permissions(tmpdir, permissions_octal): + f = tmpdir.join('test.txt') + f.write('test') + os.chmod(f.strpath, permissions_octal) + + assert utils.get_file_permissions(f.strpath) == permissions_octal diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..41a0613 --- /dev/null +++ b/tox.ini @@ -0,0 +1,10 @@ +[tox] +envlist = py27 + +[testenv] +setenv = + AWS_DEFAULT_REGION = us-east-1 +commands = py.test [] +deps = + pytest + git+https://github.com/austinmoore-/moto.git@405d8c63b4b735a41aa4938675506fe40517bfa3#egg=moto