From 47c863276efbebf86ff1a933bf8e9c7092b4121e Mon Sep 17 00:00:00 2001 From: Mansi Sharma Date: Thu, 12 Jan 2023 13:49:31 -0800 Subject: [PATCH] parameterize cert and key locations --- .github/workflows/pki-certs-location.yml | 34 +++ .github/workflows/pki.yml | 16 - openfl/interface/aggregator.py | 219 +++++++++----- openfl/interface/collaborator.py | 367 +++++++++++++++-------- openfl/interface/workspace.py | 97 ++++-- tests/github/test_pki_cert_location.py | 32 +- tests/github/utils.py | 33 +- 7 files changed, 535 insertions(+), 263 deletions(-) create mode 100644 .github/workflows/pki-certs-location.yml diff --git a/.github/workflows/pki-certs-location.yml b/.github/workflows/pki-certs-location.yml new file mode 100644 index 00000000000..d24be212a24 --- /dev/null +++ b/.github/workflows/pki-certs-location.yml @@ -0,0 +1,34 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Test PKI certs location + +on: + push: + branches: [ develop ] + pull_request: + branches: [ develop ] + +permissions: + contents: read + +jobs: + build: + strategy: + matrix: + os: ['ubuntu-latest', 'windows-latest'] + runs-on: ${{ matrix.os }} + + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.8 + uses: actions/setup-python@v3 + with: + python-version: "3.8" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install . + - name: Test PKI certs location + run: | + python -m tests.github.test_pki_cert_location --agg-cert-path ~/.openfl/agg --agg-key-path ~/.openfl/agg --col1-cert-path ~/.openfl/col1/ --col1-key-path ~/.openfl/col1 --col2-cert-path ~/.openfl/col2 --col2-key-path ~/.openfl/col2/ diff --git a/.github/workflows/pki.yml b/.github/workflows/pki.yml index 224c0d40c0e..2bea96efe72 100644 --- a/.github/workflows/pki.yml +++ b/.github/workflows/pki.yml @@ -47,20 +47,4 @@ jobs: - name: Test PKI run: | python tests/github/pki_wrong_cn.py - test_pki_cert_location: - - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.8 - uses: actions/setup-python@v3 - with: - python-version: "3.8" - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install . - - name: Test PKI - run: | - python tests/github/test_pki_cert_location.py diff --git a/openfl/interface/aggregator.py b/openfl/interface/aggregator.py index 618adae2558..7358e5efbdc 100644 --- a/openfl/interface/aggregator.py +++ b/openfl/interface/aggregator.py @@ -37,8 +37,10 @@ def aggregator(context): @option('-s', '--secure', required=False, help='Enable Intel SGX Enclave', is_flag=True, default=False) @option('-c', '--cert_path', - help='The cert path where pki certs will reside', required=False) -def start_(plan, authorized_cols, secure, cert_path): + help='The path where aggregator certificate resides', required=False) +@option('-k', '--key_path', + help='The path where aggregator key resides', required=False) +def start_(plan, authorized_cols, secure, cert_path, key_path): """Start the aggregator service.""" from pathlib import Path @@ -56,19 +58,18 @@ def start_(plan, authorized_cols, secure, cert_path): logger.info('🧿 Starting the Aggregator Service.') - if cert_path: + if cert_path and key_path: cert_path = Path(cert_path).absolute() - (cert_path / 'cert').mkdir(parents=True, exist_ok=True) - cert_dir_path = cert_path / 'cert' - if not Path(cert_dir_path).exists(): + key_path = Path(key_path).absolute() + if not Path(cert_path).exists() or not Path(key_path).exists(): echo(style('Certificate Path not found.', fg='red') + ' Please run `fx aggregator generate-cert-request --cert_path`' ' to generate certs under this directory first.') common_name = plan.config['network']['settings']['agg_addr'].lower() - plan.get_server(root_certificate=f'{cert_dir_path}/cert_chain.crt', - private_key=f'{cert_dir_path}/server/agg_{common_name}.key', - certificate=f'{cert_dir_path}/server/agg_{common_name}.crt').serve() + plan.get_server(root_certificate=f'{cert_path}/cert_chain.crt', + private_key=f'{cert_path}/agg_{common_name}.key', + certificate=f'{key_path}/agg_{common_name}.crt').serve() else: plan.get_server().serve() @@ -79,12 +80,14 @@ def start_(plan, authorized_cols, secure, cert_path): f' aggregator node [{getfqdn_env()}]', default=getfqdn_env()) @option('-c', '--cert_path', - help='The cert path where pki certs will reside', required=False) -def _generate_cert_request(fqdn, cert_path): - generate_cert_request(fqdn, cert_path) + help='The path where aggregator certificate will reside', required=False) +@option('-k', '--key_path', + help='The path where aggregator key will reside', required=False) +def _generate_cert_request(fqdn, cert_path, key_path): + generate_cert_request(fqdn, cert_path, key_path) -def generate_cert_request(fqdn, cert_path=None): +def generate_cert_request(fqdn, cert_path=None, key_path=None): """Create aggregator certificate key pair.""" from pathlib import Path from openfl.cryptography.participant import generate_csr @@ -105,21 +108,29 @@ def generate_cert_request(fqdn, cert_path=None): server_private_key, server_csr = generate_csr(common_name, server=True) - if cert_path: + if cert_path and key_path: cert_path = Path(cert_path).absolute() - (cert_path / 'cert').mkdir(parents=True, exist_ok=True) - cert_dir_path = cert_path / 'cert' - else: - cert_dir_path = CERT_DIR + key_path = Path(key_path).absolute() + + echo(' Writing AGGREGATOR certificate to: ' + style( + f'{cert_path}', fg='green')) + echo(' Writing AGGREGATOR key to: ' + style( + f'{key_path}', fg='green')) - (cert_dir_path / 'server').mkdir(parents=True, exist_ok=True) + # Write aggregator csr and key to disk + write_crt(server_csr, cert_path / f'{file_name}.csr') + write_key(server_private_key, key_path / f'{file_name}.key') + else: + if cert_path and not key_path or not cert_path and key_path: + echo(f'Both cert_path and key_path should be provided. Using default {CERT_DIR}.') + (CERT_DIR / 'server').mkdir(parents=True, exist_ok=True) - echo(' Writing AGGREGATOR certificate key pair to: ' + style( - f'{cert_dir_path}/server', fg='green')) + echo(' Writing AGGREGATOR certificate key pair to: ' + style( + f'{CERT_DIR}/server', fg='green')) - # Write aggregator csr and key to disk - write_crt(server_csr, cert_dir_path / 'server' / f'{file_name}.csr') - write_key(server_private_key, cert_dir_path / 'server' / f'{file_name}.key') + # Write aggregator csr and key to disk + write_crt(server_csr, CERT_DIR / 'server' / f'{file_name}.csr') + write_key(server_private_key, CERT_DIR / 'server' / f'{file_name}.key') # TODO: function not used @@ -140,12 +151,14 @@ def find_certificate_name(file_name): default=getfqdn_env()) @option('-s', '--silent', help='Do not prompt', is_flag=True) @option('-c', '--cert_path', - help='The cert path where pki certs will reside', required=False) -def _certify(fqdn, silent, cert_path): - certify(fqdn, silent, cert_path) + help='The path where signing CA certificate resides', required=False) +@option('-k', '--key_path', + help='The path where signing CA key resides', required=False) +def _certify(fqdn, silent, cert_path, key_path): + certify(fqdn, silent, cert_path, key_path) -def certify(fqdn, silent, cert_path=None): +def certify(fqdn, silent, cert_path=None, key_path=None): """Sign/certify the aggregator certificate key pair.""" from pathlib import Path @@ -163,82 +176,140 @@ def certify(fqdn, silent, cert_path=None): common_name = f'{fqdn}'.lower() file_name = f'agg_{common_name}' - cert_name = f'server/{file_name}' - signing_key_path = 'ca/signing-ca/private/signing-ca.key' - signing_crt_path = 'ca/signing-ca.crt' # Load CSR - if cert_path: + if cert_path and key_path: cert_path = Path(cert_path).absolute() - (cert_path / 'cert').mkdir(parents=True, exist_ok=True) - cert_dir_path = cert_path / 'cert' - else: - cert_dir_path = CERT_DIR + key_path = Path(key_path).absolute() + + agg_cert_name = f'{file_name}' + csr_path_absolute_path = Path(cert_path / f'{agg_cert_name}.csr').absolute() + + if not csr_path_absolute_path.exists(): + echo(style('Aggregator certificate signing request not found.', fg='red') + + ' Please run `fx aggregator generate-cert-request -c -k`' + ' to generate the certificate request.') - csr_path_absolute_path = Path(cert_dir_path / f'{cert_name}.csr').absolute() - if not csr_path_absolute_path.exists(): - echo(style('Aggregator certificate signing request not found.', fg='red') - + ' Please run `fx aggregator generate-cert-request`' - ' to generate the certificate request.') + csr, csr_hash = read_csr(csr_path_absolute_path) - csr, csr_hash = read_csr(csr_path_absolute_path) + # Load private signing key + signing_key_path = 'signing-ca.key' + private_sign_key_absolute_path = Path(cert_path / signing_key_path).absolute() + if not private_sign_key_absolute_path.exists(): + echo(style('Signing key not found.', fg='red') + + ' Please run `fx workspace certify -c -k`' + ' to initialize the local certificate authority.') - # Load private signing key - private_sign_key_absolute_path = Path(cert_dir_path / signing_key_path).absolute() - if not private_sign_key_absolute_path.exists(): - echo(style('Signing key not found.', fg='red') - + ' Please run `fx workspace certify`' - ' to initialize the local certificate authority.') + signing_key = read_key(private_sign_key_absolute_path) - signing_key = read_key(private_sign_key_absolute_path) + # Load signing cert + signing_crt_path = 'signing-ca.crt' + signing_crt_absolute_path = Path(cert_path / signing_crt_path).absolute() + if not signing_crt_absolute_path.exists(): + echo(style('Signing certificate not found.', fg='red') + + ' Please run `fx workspace certify -c -k`' + ' to initialize the local certificate authority.') - # Load signing cert - signing_crt_absolute_path = Path(cert_dir_path / signing_crt_path).absolute() - if not signing_crt_absolute_path.exists(): - echo(style('Signing certificate not found.', fg='red') - + ' Please run `fx workspace certify`' - ' to initialize the local certificate authority.') + signing_crt = read_crt(signing_crt_absolute_path) - signing_crt = read_crt(signing_crt_absolute_path) + echo('The CSR Hash for file ' + + style(f'{agg_cert_name}.csr', fg='green') + + ' = ' + + style(f'{csr_hash}', fg='red')) - echo('The CSR Hash for file ' - + style(f'{cert_name}.csr', fg='green') - + ' = ' - + style(f'{csr_hash}', fg='red')) + crt_path_absolute_path = Path(cert_path / f'{agg_cert_name}.crt').absolute() - crt_path_absolute_path = Path(cert_dir_path / f'{cert_name}.crt').absolute() + if silent: + + echo(' Signing AGGREGATOR certificate') + signed_agg_cert = sign_certificate(csr, signing_key, signing_crt.subject) + write_crt(signed_agg_cert, crt_path_absolute_path) + + else: - if silent: + if confirm('Do you want to sign this certificate?'): - echo(' Signing AGGREGATOR certificate') - signed_agg_cert = sign_certificate(csr, signing_key, signing_crt.subject) - write_crt(signed_agg_cert, crt_path_absolute_path) + echo(' Signing AGGREGATOR certificate') + signed_agg_cert = sign_certificate(csr, signing_key, signing_crt.subject) + write_crt(signed_agg_cert, crt_path_absolute_path) + + else: + echo(style('Not signing certificate.', fg='red') + + ' Please check with this AGGREGATOR to get the correct' + ' certificate for this federation.') else: + agg_cert_name = f'server/{file_name}' + signing_key_path = 'ca/signing-ca/private/signing-ca.key' + signing_crt_path = 'ca/signing-ca.crt' + csr_path_absolute_path = Path(CERT_DIR / f'{agg_cert_name}.csr').absolute() + if not csr_path_absolute_path.exists(): + echo(style('Aggregator certificate signing request not found.', fg='red') + + ' Please run `fx aggregator generate-cert-request`' + ' to generate the certificate request.') + + csr, csr_hash = read_csr(csr_path_absolute_path) + + # Load private signing key + private_sign_key_absolute_path = Path(CERT_DIR / signing_key_path).absolute() + if not private_sign_key_absolute_path.exists(): + echo(style('Signing key not found.', fg='red') + + ' Please run `fx workspace certify`' + ' to initialize the local certificate authority.') + + signing_key = read_key(private_sign_key_absolute_path) + + # Load signing cert + signing_crt_absolute_path = Path(CERT_DIR / signing_crt_path).absolute() + if not signing_crt_absolute_path.exists(): + echo(style('Signing certificate not found.', fg='red') + + ' Please run `fx workspace certify`' + ' to initialize the local certificate authority.') + + signing_crt = read_crt(signing_crt_absolute_path) - if confirm('Do you want to sign this certificate?'): + echo('The CSR Hash for file ' + + style(f'{agg_cert_name}.csr', fg='green') + + ' = ' + + style(f'{csr_hash}', fg='red')) + + crt_path_absolute_path = Path(CERT_DIR / f'{agg_cert_name}.crt').absolute() + + if silent: echo(' Signing AGGREGATOR certificate') signed_agg_cert = sign_certificate(csr, signing_key, signing_crt.subject) write_crt(signed_agg_cert, crt_path_absolute_path) else: - echo(style('Not signing certificate.', fg='red') - + ' Please check with this AGGREGATOR to get the correct' - ' certificate for this federation.') + + if confirm('Do you want to sign this certificate?'): + + echo(' Signing AGGREGATOR certificate') + signed_agg_cert = sign_certificate(csr, signing_key, signing_crt.subject) + write_crt(signed_agg_cert, crt_path_absolute_path) + + else: + echo(style('Not signing certificate.', fg='red') + + ' Please check with this AGGREGATOR to get the correct' + ' certificate for this federation.') @aggregator.command(name='uninstall-cert') @option('-c', '--cert_path', help='The cert path where pki certs reside', required=True) -def _uninstall_cert(cert_path): - uninstall_cert(cert_path) +@option('-k', '--key_path', + help='The key path where key reside', required=True) +def _uninstall_cert(cert_path, key_path): + uninstall_cert(cert_path, key_path) -def uninstall_cert(cert_path=None): +def uninstall_cert(cert_path=None, key_path=None): """Uninstall certs under a given directory.""" - import shutil + from openfl.utilities.utils import rmtree from pathlib import Path cert_path = Path(cert_path).absolute() - shutil.rmtree(cert_path, ignore_errors=True) + rmtree(cert_path, ignore_errors=True) + key_path = Path(key_path).absolute() + rmtree(key_path, ignore_errors=True) diff --git a/openfl/interface/collaborator.py b/openfl/interface/collaborator.py index 92ee1c23924..9dc62a24f55 100644 --- a/openfl/interface/collaborator.py +++ b/openfl/interface/collaborator.py @@ -38,8 +38,10 @@ def collaborator(context): @option('-s', '--secure', required=False, help='Enable Intel SGX Enclave', is_flag=True, default=False) @option('-c', '--cert_path', - help='The cert path where pki certs will reside', required=False) -def start_(plan, collaborator_name, data_config, secure, cert_path): + help='The path where collaborator certificate resides', required=False) +@option('-k', '--key_path', + help='The path where collaborator key resides', required=False) +def start_(plan, collaborator_name, data_config, secure, cert_path, key_path): """Start a collaborator service.""" from pathlib import Path @@ -60,19 +62,18 @@ def start_(plan, collaborator_name, data_config, secure, cert_path): echo(f'Data = {plan.cols_data_paths}') logger.info('🧿 Starting a Collaborator Service.') - if cert_path: + if cert_path and key_path: cert_path = Path(cert_path).absolute() - (cert_path / 'cert').mkdir(parents=True, exist_ok=True) - cert_dir_path = cert_path / 'cert' - if not Path(cert_dir_path).exists(): + key_path = Path(key_path).absolute() + if not Path(cert_path).exists() or not Path(key_path).exists(): echo(style('Certificate Path not found.', fg='red') - + ' Please run `fx collaborator generate-cert-request --cert_path`' + + ' Please run `fx collaborator generate-cert-request -c -k`' ' to generate certs under this directory first.') common_name = f'{collaborator_name}'.lower() plan.get_collaborator(collaborator_name, - root_certificate=f'{cert_dir_path}/cert_chain.crt', - private_key=f'{cert_dir_path}/client/col_{common_name}.key', - certificate=f'{cert_dir_path}/client/col_{common_name}.crt').run() + root_certificate=f'{cert_path}/cert_chain.crt', + private_key=f'{cert_path}/col_{common_name}.key', + certificate=f'{cert_path}/col_{common_name}.crt').run() else: plan.get_collaborator(collaborator_name).run() @@ -134,17 +135,20 @@ def register_data_path(collaborator_name, data_path=None, silent=False): help='Do not package the certificate signing request for export', is_flag=True) @option('-c', '--cert_path', - help='The cert path where pki certs will reside', required=False) + help='The path where collaborator certificate resides', required=False) +@option('-k', '--key_path', + help='The path where collaborator key resides', required=False) def generate_cert_request_(collaborator_name, - data_path, silent, skip_package, cert_path): + data_path, silent, skip_package, cert_path, key_path): """Generate certificate request for the collaborator.""" if data_path and is_directory_traversal(data_path): echo('Data path is out of the openfl workspace scope.') sys.exit(1) - generate_cert_request(collaborator_name, data_path, silent, skip_package, cert_path) + generate_cert_request(collaborator_name, data_path, silent, skip_package, cert_path, key_path) -def generate_cert_request(collaborator_name, data_path, silent, skip_package, cert_path=None): +def generate_cert_request(collaborator_name, data_path, silent, skip_package, + cert_path=None, key_path=None): """ Create collaborator certificate key pair. @@ -166,21 +170,27 @@ def generate_cert_request(collaborator_name, data_path, silent, skip_package, ce client_private_key, client_csr = generate_csr(common_name, server=False) - if cert_path: + if cert_path and key_path: cert_path = Path(cert_path).absolute() - (cert_path / 'cert').mkdir(parents=True, exist_ok=True) - cert_dir_path = cert_path / 'cert' - else: - cert_dir_path = CERT_DIR + key_path = Path(key_path).absolute() + + echo(' Moving COLLABORATOR certificate to: ' + style( + f'{cert_path}', fg='green')) + echo(' Moving COLLABORATOR key to: ' + style( + f'{key_path}', fg='green')) - (cert_dir_path / 'client').mkdir(parents=True, exist_ok=True) + # Write collaborator csr and key to disk + write_crt(client_csr, cert_path / f'{file_name}.csr') + write_key(client_private_key, key_path / f'{file_name}.key') + else: + (CERT_DIR / 'client').mkdir(parents=True, exist_ok=True) - echo(' Moving COLLABORATOR certificate to: ' + style( - f'{cert_dir_path}', fg='green')) + echo(' Moving COLLABORATOR certificate key pair to: ' + style( + f'{CERT_DIR}', fg='green')) - # Write collaborator csr and key to disk - write_crt(client_csr, cert_dir_path / 'client' / f'{file_name}.csr') - write_key(client_private_key, cert_dir_path / 'client' / f'{file_name}.key') + # Write collaborator csr and key to disk + write_crt(client_csr, CERT_DIR / 'client' / f'{file_name}.csr') + write_key(client_private_key, CERT_DIR / 'client' / f'{file_name}.key') if not skip_package: from shutil import copytree @@ -201,7 +211,11 @@ def generate_cert_request(collaborator_name, data_path, silent, skip_package, ce ignore = ignore_patterns('__pycache__', '*.key', '*.srl', '*.pem') # Copy the current directory into the temporary directory - copytree(f'{cert_dir_path}/client', tmp_dir, ignore=ignore) + if cert_path: + cert_path = Path(cert_path).absolute() + copytree(cert_path, tmp_dir, ignore=ignore) + else: + copytree(f'{CERT_DIR}/client', tmp_dir, ignore=ignore) for f in glob(f'{tmp_dir}/*'): if common_name not in basename(f): @@ -285,13 +299,16 @@ def register_collaborator(file_name): help='Import the archive containing the collaborator\'s' ' certificate (signed by the CA)') @option('-c', '--cert_path', - help='The cert path where pki certs will reside', required=False) -def certify_(collaborator_name, silent, request_pkg, import_, cert_path): + help='The path where signing CA certificate resides', required=False) +@option('-k', '--key_path', + help='The path where signing CA key resides', required=False) +def certify_(collaborator_name, silent, request_pkg, import_, cert_path, key_path): """Certify the collaborator.""" - certify(collaborator_name, silent, request_pkg, import_, cert_path) + certify(collaborator_name, silent, request_pkg, import_, cert_path, key_path) -def certify(collaborator_name, silent, request_pkg=None, import_=False, cert_path=None): +def certify(collaborator_name, silent, request_pkg=None, import_=False, + cert_path=None, key_path=None): """Sign/certify collaborator certificate key pair.""" from click import confirm from pathlib import Path @@ -313,131 +330,241 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False, cert_pat common_name = f'{collaborator_name}'.lower() - if cert_path: + if cert_path and key_path: cert_path = Path(cert_path).absolute() - (cert_path / 'cert').mkdir(parents=True, exist_ok=True) - cert_dir_path = cert_path / 'cert' - else: - cert_dir_path = CERT_DIR + key_path = Path(key_path).absolute() - if not import_: - if request_pkg: - Path(f'{cert_dir_path}/client').mkdir(parents=True, exist_ok=True) - unpack_archive(request_pkg, extract_dir=f'{cert_dir_path}/client') - csr = glob(f'{cert_dir_path}/client/*.csr')[0] - else: - if collaborator_name is None: - echo('collaborator_name can only be omitted if signing\n' - 'a zipped request package.\n' - '\n' - 'Example: fx collaborator certify --request-pkg ' - 'col_one_to_agg_cert_request.zip') - return - csr = glob(f'{cert_dir_path}/client/col_{common_name}.csr')[0] - copy(csr, cert_dir_path) - cert_name = splitext(csr)[0] - file_name = basename(cert_name) - signing_key_path = 'ca/signing-ca/private/signing-ca.key' - signing_crt_path = 'ca/signing-ca.crt' + if not import_: + if request_pkg: + unpack_archive(request_pkg, extract_dir=cert_path) + csr = glob(f'{cert_path}/*.csr')[0] + print("csr is", csr) + else: + if collaborator_name is None: + echo('collaborator_name can only be omitted if signing\n' + 'a zipped request package.\n' + '\n' + 'Example: fx collaborator certify --request-pkg ' + 'col_one_to_agg_cert_request.zip') + return + csr = glob(f'{cert_path}/col_{common_name}.csr')[0] + print("Csr is", csr) + copy(csr, cert_path) + cert_name = splitext(csr)[0] + print("cert name", cert_name) + file_name = basename(cert_name) + print("file name", file_name) + signing_key_path = 'signing-ca.key' + signing_crt_path = 'signing-ca.crt' + + # Load CSR + if not Path(f'{cert_name}.csr').exists(): + echo(style('Collaborator certificate signing request not found.', fg='red') + + ' Please run `fx collaborator generate-cert-request -c -k`' + ' to generate the certificate request.') + + csr, csr_hash = read_csr(f'{cert_name}.csr') + + # Load private signing key + if not Path(key_path / signing_key_path).exists(): + echo(style('Signing key not found.', fg='red') + + ' Please run `fx workspace certify`' + ' to initialize the local certificate authority.') + + signing_key = read_key(key_path / signing_key_path) + + # Load signing cert + if not Path(cert_path / signing_crt_path).exists(): + echo(style('Signing certificate not found.', fg='red') + + ' Please run `fx workspace certify`' + ' to initialize the local certificate authority.') + + signing_crt = read_crt(cert_path / signing_crt_path) + + echo('The CSR Hash for file ' + + style(f'{file_name}.csr', fg='green') + + ' = ' + + style(f'{csr_hash}', fg='red')) + + if silent: - # Load CSR - if not Path(f'{cert_name}.csr').exists(): - echo(style('Collaborator certificate signing request not found.', fg='red') - + ' Please run `fx collaborator generate-cert-request`' - ' to generate the certificate request.') + echo(' Signing COLLABORATOR certificate') + signed_col_cert = sign_certificate(csr, signing_key, signing_crt.subject) + write_crt(signed_col_cert, f'{cert_name}.crt') + register_collaborator(cert_path / f'{file_name}.crt') - csr, csr_hash = read_csr(f'{cert_name}.csr') + else: - # Load private signing key - if not Path(cert_dir_path / signing_key_path).exists(): - echo(style('Signing key not found.', fg='red') - + ' Please run `fx workspace certify`' - ' to initialize the local certificate authority.') + if confirm('Do you want to sign this certificate?'): - signing_key = read_key(cert_dir_path / signing_key_path) + echo(' Signing COLLABORATOR certificate') + signed_col_cert = sign_certificate(csr, signing_key, signing_crt.subject) + write_crt(signed_col_cert, f'{cert_name}.crt') + register_collaborator(cert_path / f'{file_name}.crt') - # Load signing cert - if not Path(cert_dir_path / signing_crt_path).exists(): - echo(style('Signing certificate not found.', fg='red') - + ' Please run `fx workspace certify`' - ' to initialize the local certificate authority.') + else: + echo(style('Not signing certificate.', fg='red') + + ' Please check with this collaborator to get the' + ' correct certificate for this federation.') + return - signing_crt = read_crt(cert_dir_path / signing_crt_path) + if len(common_name) == 0: + # If the collaborator name is provided, the collaborator and + # certificate does not need to be exported + return - echo('The CSR Hash for file ' - + style(f'{file_name}.csr', fg='green') - + ' = ' - + style(f'{csr_hash}', fg='red')) + # Remove unneeded CSR + remove(f'{cert_name}.csr') - if silent: + archive_type = 'zip' + archive_name = f'agg_to_{file_name}_signed_cert' - echo(' Signing COLLABORATOR certificate') - signed_col_cert = sign_certificate(csr, signing_key, signing_crt.subject) - write_crt(signed_col_cert, f'{cert_name}.crt') - register_collaborator(cert_dir_path / 'client' / f'{file_name}.crt') + # Collaborator certificate signing request + tmp_dir = join(mkdtemp(), 'openfl', archive_name) - else: + Path(tmp_dir).mkdir(parents=True, exist_ok=True) + # Copy the signed cert to the temporary directory + copy(f'{cert_path}/{file_name}.crt', f'{tmp_dir}/') + # Copy the CA certificate chain to the temporary directory + copy(f'{cert_path}/cert_chain.crt', tmp_dir) - if confirm('Do you want to sign this certificate?'): + # Create Zip archive of directory + make_archive(archive_name, archive_type, tmp_dir) + + else: + # Copy the signed certificate and cert chain into PKI_DIR + previous_crts = glob(f'{cert_path}/*.crt') + unpack_archive(import_, extract_dir=cert_path) + updated_crts = glob(f'{cert_path}/*.crt') + cert_difference = list(set(updated_crts) - set(previous_crts)) + if len(cert_difference) != 0: + crt = basename(cert_difference[0]) + echo(f'Certificate {crt} installed to PKI directory') + else: + echo('Certificate updated in the PKI directory') + else: + if not import_: + if request_pkg: + Path(f'{CERT_DIR}/client').mkdir(parents=True, exist_ok=True) + unpack_archive(request_pkg, extract_dir=f'{CERT_DIR}/client') + csr = glob(f'{CERT_DIR}/client/*.csr')[0] + else: + if collaborator_name is None: + echo('collaborator_name can only be omitted if signing\n' + 'a zipped request package.\n' + '\n' + 'Example: fx collaborator certify --request-pkg ' + 'col_one_to_agg_cert_request.zip') + return + csr = glob(f'{CERT_DIR}/client/col_{common_name}.csr')[0] + copy(csr, CERT_DIR) + cert_name = splitext(csr)[0] + file_name = basename(cert_name) + signing_key_path = 'ca/signing-ca/private/signing-ca.key' + signing_crt_path = 'ca/signing-ca.crt' + + # Load CSR + if not Path(f'{cert_name}.csr').exists(): + echo(style('Collaborator certificate signing request not found.', fg='red') + + ' Please run `fx collaborator generate-cert-request`' + ' to generate the certificate request.') + + csr, csr_hash = read_csr(f'{cert_name}.csr') + + # Load private signing key + if not Path(CERT_DIR / signing_key_path).exists(): + echo(style('Signing key not found.', fg='red') + + ' Please run `fx workspace certify`' + ' to initialize the local certificate authority.') + + signing_key = read_key(CERT_DIR / signing_key_path) + + # Load signing cert + if not Path(CERT_DIR / signing_crt_path).exists(): + echo(style('Signing certificate not found.', fg='red') + + ' Please run `fx workspace certify`' + ' to initialize the local certificate authority.') + + signing_crt = read_crt(CERT_DIR / signing_crt_path) + + echo('The CSR Hash for file ' + + style(f'{file_name}.csr', fg='green') + + ' = ' + + style(f'{csr_hash}', fg='red')) + + if silent: echo(' Signing COLLABORATOR certificate') signed_col_cert = sign_certificate(csr, signing_key, signing_crt.subject) write_crt(signed_col_cert, f'{cert_name}.crt') - register_collaborator(cert_dir_path / 'client' / f'{file_name}.crt') + register_collaborator(CERT_DIR / 'client' / f'{file_name}.crt') else: - echo(style('Not signing certificate.', fg='red') - + ' Please check with this collaborator to get the' - ' correct certificate for this federation.') - return - if len(common_name) == 0: - # If the collaborator name is provided, the collaborator and - # certificate does not need to be exported - return + if confirm('Do you want to sign this certificate?'): - # Remove unneeded CSR - remove(f'{cert_name}.csr') + echo(' Signing COLLABORATOR certificate') + signed_col_cert = sign_certificate(csr, signing_key, signing_crt.subject) + write_crt(signed_col_cert, f'{cert_name}.crt') + register_collaborator(CERT_DIR / 'client' / f'{file_name}.crt') - archive_type = 'zip' - archive_name = f'agg_to_{file_name}_signed_cert' + else: + echo(style('Not signing certificate.', fg='red') + + ' Please check with this collaborator to get the' + ' correct certificate for this federation.') + return - # Collaborator certificate signing request - tmp_dir = join(mkdtemp(), 'openfl', archive_name) + if len(common_name) == 0: + # If the collaborator name is provided, the collaborator and + # certificate does not need to be exported + return - Path(f'{tmp_dir}/client').mkdir(parents=True, exist_ok=True) - # Copy the signed cert to the temporary directory - copy(f'{cert_dir_path}/client/{file_name}.crt', f'{tmp_dir}/client/') - # Copy the CA certificate chain to the temporary directory - copy(f'{cert_dir_path}/cert_chain.crt', tmp_dir) + # Remove unneeded CSR + remove(f'{cert_name}.csr') - # Create Zip archive of directory - make_archive(archive_name, archive_type, tmp_dir) + archive_type = 'zip' + archive_name = f'agg_to_{file_name}_signed_cert' + + # Collaborator certificate signing request + tmp_dir = join(mkdtemp(), 'openfl', archive_name) + + Path(f'{tmp_dir}/client').mkdir(parents=True, exist_ok=True) + # Copy the signed cert to the temporary directory + copy(f'{CERT_DIR}/client/{file_name}.crt', f'{tmp_dir}/client/') + # Copy the CA certificate chain to the temporary directory + copy(f'{CERT_DIR}/cert_chain.crt', tmp_dir) + + # Create Zip archive of directory + make_archive(archive_name, archive_type, tmp_dir) - else: - # Copy the signed certificate and cert chain into PKI_DIR - previous_crts = glob(f'{cert_dir_path}/client/*.crt') - unpack_archive(import_, extract_dir=cert_dir_path) - updated_crts = glob(f'{cert_dir_path}/client/*.crt') - cert_difference = list(set(updated_crts) - set(previous_crts)) - if len(cert_difference) != 0: - crt = basename(cert_difference[0]) - echo(f'Certificate {crt} installed to PKI directory') else: - echo('Certificate updated in the PKI directory') + # Copy the signed certificate and cert chain into PKI_DIR + previous_crts = glob(f'{CERT_DIR}/client/*.crt') + unpack_archive(import_, extract_dir=CERT_DIR) + updated_crts = glob(f'{CERT_DIR}/client/*.crt') + cert_difference = list(set(updated_crts) - set(previous_crts)) + if len(cert_difference) != 0: + crt = basename(cert_difference[0]) + echo(f'Certificate {crt} installed to PKI directory') + else: + echo('Certificate updated in the PKI directory') @collaborator.command(name='uninstall-cert') @option('-c', '--cert_path', help='The cert path where pki certs reside', required=True) -def _uninstall_cert(cert_path): - uninstall_cert(cert_path) +@option('-k', '--key_path', + help='The key path where key reside', required=True) +def _uninstall_cert(cert_path, key_path): + uninstall_cert(cert_path, key_path) -def uninstall_cert(cert_path=None): +def uninstall_cert(cert_path=None, key_path=None): """Uninstall certs under a given directory.""" - import shutil + from openfl.utilities.utils import rmtree from pathlib import Path cert_path = Path(cert_path).absolute() - shutil.rmtree(cert_path, ignore_errors=True) + rmtree(cert_path, ignore_errors=True) + key_path = Path(key_path).absolute() + rmtree(key_path, ignore_errors=True) diff --git a/openfl/interface/workspace.py b/openfl/interface/workspace.py index b728b57eb51..c6105412fbe 100644 --- a/openfl/interface/workspace.py +++ b/openfl/interface/workspace.py @@ -213,14 +213,18 @@ def import_(archive): @workspace.command(name='certify') +@option('-cdir', '--cert_dir', + help='The cert directory path where CA certs and keys will reside', required=False) @option('-c', '--cert_path', - help='The cert path where pki certs will reside', required=False) -def certify_(cert_path): + help='The cert path where CA signing cert will reside', required=False) +@option('-k', '--key_path', + help='The cert path where CA key path will reside', required=False) +def certify_(cert_dir, cert_path, key_path): """Create certificate authority for federation.""" - certify(cert_path) + certify(cert_dir, cert_path, key_path) -def certify(cert_path=None): +def certify(cert_dir=None, cert_path=None, key_path=None): """Create certificate authority for federation.""" from cryptography.hazmat.primitives import serialization @@ -234,10 +238,10 @@ def certify(cert_path=None): echo('1. Create Root CA') echo('1.1 Create Directories') - if cert_path: - cert_path = Path(cert_path).absolute() - (cert_path / 'cert').mkdir(parents=True, exist_ok=True) - cert_dir_path = cert_path / 'cert' + if cert_dir: + cert_dir = Path(cert_dir).absolute() + (cert_dir / 'cert').mkdir(parents=True, exist_ok=True) + cert_dir_path = cert_dir / 'cert' else: cert_dir_path = CERT_DIR @@ -298,42 +302,67 @@ def certify(cert_path=None): echo('2.3 Create Signing Certificate CSR') - signing_csr_path = 'ca/signing-ca.csr' - signing_crt_path = 'ca/signing-ca.crt' - signing_key_path = 'ca/signing-ca/private/signing-ca.key' - signing_private_key, signing_csr = generate_signing_csr() # Write Signing CA CSR to disk + signing_csr_path = 'ca/signing-ca.csr' with open(cert_dir_path / signing_csr_path, 'wb') as f: f.write(signing_csr.public_bytes( encoding=serialization.Encoding.PEM, )) - with open(cert_dir_path / signing_key_path, 'wb') as f: - f.write(signing_private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() - )) + if key_path: + key_path = Path(key_path).absolute() + signing_key_path = 'signing-ca.key' + with open(key_path / signing_key_path, 'wb') as f: + f.write(signing_private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption() + )) + else: + signing_key_path = 'ca/signing-ca/private/signing-ca.key' + with open(cert_dir_path / signing_key_path, 'wb') as f: + f.write(signing_private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption() + )) echo('2.4 Sign Signing Certificate CSR') signing_cert = sign_certificate(signing_csr, root_private_key, root_cert.subject, ca=True) - with open(cert_dir_path / signing_crt_path, 'wb') as f: - f.write(signing_cert.public_bytes( - encoding=serialization.Encoding.PEM, - )) + if cert_path: + cert_path = Path(cert_path).absolute() + signing_crt_path = 'signing-ca.crt' + with open(cert_path / signing_crt_path, 'wb') as f: + f.write(signing_cert.public_bytes( + encoding=serialization.Encoding.PEM, + )) + else: + signing_crt_path = 'ca/signing-ca.crt' + with open(cert_dir_path / signing_crt_path, 'wb') as f: + f.write(signing_cert.public_bytes( + encoding=serialization.Encoding.PEM, + )) echo('3 Create Certificate Chain') # create certificate chain file by combining root-ca and signing-ca - with open(cert_dir_path / 'cert_chain.crt', 'w', encoding='utf-8') as d: - with open(cert_dir_path / 'ca/root-ca.crt', encoding='utf-8') as s: - d.write(s.read()) - with open(cert_dir_path / 'ca/signing-ca.crt') as s: - d.write(s.read()) + if cert_path: + cert_chain_path = Path(cert_path).absolute() + with open(cert_chain_path / 'cert_chain.crt', 'w', encoding='utf-8') as d: + with open(cert_dir_path / 'ca/root-ca.crt', encoding='utf-8') as s: + d.write(s.read()) + with open(cert_chain_path / 'signing-ca.crt') as s: + d.write(s.read()) + else: + with open(cert_dir_path / 'cert_chain.crt', 'w', encoding='utf-8') as d: + with open(cert_dir_path / 'ca/root-ca.crt', encoding='utf-8') as s: + d.write(s.read()) + with open(cert_dir_path / 'ca/signing-ca.crt') as s: + d.write(s.read()) echo('\nDone.') @@ -538,17 +567,21 @@ def open_pipe(command: str): @workspace.command(name='uninstall-cert') @option('-c', '--cert_path', help='The cert path where pki certs reside', required=True) -def _uninstall_cert(cert_path): - uninstall_cert(cert_path) +@option('-k', '--key_path', + help='The key path where key reside', required=True) +def _uninstall_cert(cert_path, key_path): + uninstall_cert(cert_path, key_path) -def uninstall_cert(cert_path=None): +def uninstall_cert(cert_path=None, key_path=None): """Uninstall certs under a given directory.""" - import shutil + from openfl.utilities.utils import rmtree from pathlib import Path cert_path = Path(cert_path).absolute() - shutil.rmtree(cert_path, ignore_errors=True) + rmtree(cert_path, ignore_errors=True) + key_path = Path(key_path).absolute() + rmtree(key_path, ignore_errors=True) def apply_template_plan(prefix, template): diff --git a/tests/github/test_pki_cert_location.py b/tests/github/test_pki_cert_location.py index cfd7b72d3e9..1af0a59c740 100644 --- a/tests/github/test_pki_cert_location.py +++ b/tests/github/test_pki_cert_location.py @@ -29,8 +29,11 @@ parser.add_argument('--col1-data-path', default='1') parser.add_argument('--col2-data-path', default='2') parser.add_argument('--agg-cert-path', default=Path.cwd()) + parser.add_argument('--agg-key-path', default=Path.cwd()) parser.add_argument('--col1-cert-path', default=Path.cwd()) + parser.add_argument('--col1-key-path', default=Path.cwd()) parser.add_argument('--col2-cert-path', default=Path.cwd()) + parser.add_argument('--col2-key-path', default=Path.cwd()) parser.add_argument('--save-model') origin_dir = Path.cwd().resolve() @@ -42,42 +45,57 @@ rounds_to_train = args.rounds_to_train col1, col2 = args.col1, args.col2 col1_data_path, col2_data_path = args.col1_data_path, args.col2_data_path + + os.makedirs(args.agg_cert_path, exist_ok=True) + os.makedirs(args.agg_key_path, exist_ok=True) + os.makedirs(args.col1_cert_path, exist_ok=True) + os.makedirs(args.col1_key_path, exist_ok=True) + os.makedirs(args.col2_cert_path, exist_ok=True) + os.makedirs(args.col2_key_path, exist_ok=True) + + ca_dir_path = Path(args.agg_cert_path).resolve() agg_cert_path = Path(args.agg_cert_path).resolve() + agg_key_path = Path(args.agg_key_path).resolve() col1_cert_path = Path(args.col1_cert_path).resolve() + col1_key_path = Path(args.col1_key_path).resolve() col2_cert_path = Path(args.col2_cert_path).resolve() + col2_key_path = Path(args.col2_key_path).resolve() save_model = args.save_model # START # ===== # Make sure you are in a Python virtual environment with the FL package installed. - create_certified_workspace(fed_workspace, template, fqdn, rounds_to_train, agg_cert_path) - certify_aggregator(fqdn, agg_cert_path) + create_certified_workspace(fed_workspace, template, fqdn, + rounds_to_train, ca_dir_path, agg_cert_path, agg_key_path) + certify_aggregator(fqdn, agg_cert_path, agg_key_path) workspace_root = Path().resolve() # Get the absolute directory path for the workspace # Create collaborator #1 create_collaborator(col1, workspace_root, col1_data_path, archive_name, - fed_workspace, col1_cert_path, agg_cert_path) + fed_workspace, col1_cert_path, col1_key_path, agg_cert_path, agg_key_path) # Create collaborator #2 create_collaborator(col2, workspace_root, col2_data_path, archive_name, - fed_workspace, col2_cert_path, agg_cert_path) + fed_workspace, col2_cert_path, col2_key_path, agg_cert_path, agg_key_path) # Run the federation with ProcessPoolExecutor(max_workers=3) as executor: executor.submit( - check_call, ['fx', 'aggregator', 'start', '-c', agg_cert_path], + check_call, ['fx', 'aggregator', 'start', '-c', agg_cert_path, '-k', agg_key_path], cwd=workspace_root) time.sleep(5) dir1 = workspace_root / col1 / fed_workspace executor.submit( - check_call, ['fx', 'collaborator', 'start', '-n', col1, '-c', col1_cert_path], + check_call, ['fx', 'collaborator', 'start', '-n', col1, + '-c', col1_cert_path, '-k', col1_key_path], cwd=dir1) dir2 = workspace_root / col2 / fed_workspace executor.submit( - check_call, ['fx', 'collaborator', 'start', '-n', col2, '-c', col2_cert_path], + check_call, ['fx', 'collaborator', 'start', '-n', col2, + '-c', col2_cert_path, '-k', col2_key_path], cwd=dir2) # Convert model to native format diff --git a/tests/github/utils.py b/tests/github/utils.py index f6a5a8be0e7..7c855840265 100644 --- a/tests/github/utils.py +++ b/tests/github/utils.py @@ -9,7 +9,7 @@ def create_collaborator(col, workspace_root, data_path, archive_name, fed_workspace, - cert_path=None, ca_cert_path=None): + cert_path=None, key_path=None, ca_cert_path=None, ca_key_path=None): # Copy workspace to collaborator directories (these can be on different machines) col_path = workspace_root / col shutil.rmtree(col_path, ignore_errors=True) # Remove any existing directory @@ -23,10 +23,10 @@ def create_collaborator(col, workspace_root, data_path, archive_name, fed_worksp # Create collaborator certificate request # Remove '--silent' if you run this manually - if cert_path: + if cert_path and key_path: check_call( ['fx', 'collaborator', 'generate-cert-request', '-d', data_path, - '-n', col, '-c', cert_path, '--silent'], + '-n', col, '-c', cert_path, '-k', key_path, '--silent'], cwd=col_path / fed_workspace ) else: @@ -39,10 +39,10 @@ def create_collaborator(col, workspace_root, data_path, archive_name, fed_worksp # Sign collaborator certificate # Remove '--silent' if you run this manually request_pkg = col_path / fed_workspace / f'col_{col}_to_agg_cert_request.zip' - if ca_cert_path: + if ca_cert_path and ca_key_path: check_call( ['fx', 'collaborator', 'certify', '--request-pkg', str(request_pkg), - '-c', ca_cert_path, '--silent'], + '-c', ca_cert_path, '-k', ca_key_path, '--silent'], cwd=workspace_root) else: check_call( @@ -51,9 +51,10 @@ def create_collaborator(col, workspace_root, data_path, archive_name, fed_worksp # Import the signed certificate from the aggregator import_path = workspace_root / f'agg_to_col_{col}_signed_cert.zip' - if cert_path: + if cert_path and key_path: check_call( - ['fx', 'collaborator', 'certify', '--import', import_path, '-c', cert_path], + ['fx', 'collaborator', 'certify', '--import', import_path, + '-c', cert_path, '-k', key_path], cwd=col_path / fed_workspace ) else: @@ -63,7 +64,8 @@ def create_collaborator(col, workspace_root, data_path, archive_name, fed_worksp ) -def create_certified_workspace(path, template, fqdn, rounds_to_train, cert_path=None): +def create_certified_workspace(path, template, fqdn, rounds_to_train, + cert_dir=None, cert_path=None, key_path=None): shutil.rmtree(path, ignore_errors=True) check_call(['fx', 'workspace', 'create', '--prefix', path, '--template', template]) os.chdir(path) @@ -83,8 +85,9 @@ def create_certified_workspace(path, template, fqdn, rounds_to_train, cert_path= except (ValueError, TypeError): pass # Create certificate authority for workspace - if cert_path: - check_call(['fx', 'workspace', 'certify', '-c', cert_path]) + if cert_dir and cert_path and key_path: + check_call(['fx', 'workspace', 'certify', + '-cdir', cert_dir, '-c', cert_path, '-k', key_path]) else: check_call(['fx', 'workspace', 'certify']) @@ -92,12 +95,14 @@ def create_certified_workspace(path, template, fqdn, rounds_to_train, cert_path= check_call(['fx', 'workspace', 'export']) -def certify_aggregator(fqdn, cert_path=None): - if cert_path: +def certify_aggregator(fqdn, cert_path=None, key_path=None): + if cert_path and key_path: # Create aggregator certificate - check_call(['fx', 'aggregator', 'generate-cert-request', '--fqdn', fqdn, '-c', cert_path]) + check_call(['fx', 'aggregator', 'generate-cert-request', '--fqdn', fqdn, + '-c', cert_path, '-k', key_path]) # Sign aggregator certificate - check_call(['fx', 'aggregator', 'certify', '--fqdn', fqdn, '-c', cert_path, '--silent']) + check_call(['fx', 'aggregator', 'certify', '--fqdn', fqdn, + '-c', cert_path, '-k', key_path, '--silent']) else: # Create aggregator certificate check_call(['fx', 'aggregator', 'generate-cert-request', '--fqdn', fqdn])