Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Mansi Sharma <[email protected]>
  • Loading branch information
Mansi Sharma committed Jan 4, 2023
1 parent b645aec commit 871aa83
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 250 deletions.
17 changes: 0 additions & 17 deletions .github/workflows/pki.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,3 @@ jobs:
run: |
python tests/github/pki_wrong_cn.py
test_pki_cert_location:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- name: Set up Python 3.8
uses: actions/setup-python@v3
with:
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .
- name: Test PKI
run: |
bash tests/github/test_pki_cert_location.sh torch_cnn_mnist aggregator col1 col2 $(hostname --all-fqdns | awk '{print $1}') --rounds-to-train 3
60 changes: 30 additions & 30 deletions openfl/interface/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,14 @@ def aggregator(context):
help='Federated learning plan [plan/plan.yaml]',
default='plan/plan.yaml',
type=ClickPath(exists=True))
@option('-c', '--authorized_cols', required=False,
@option('-col', '--authorized_cols', required=False,
help='Authorized collaborator list [plan/cols.yaml]',
default='plan/cols.yaml', type=ClickPath(exists=True))
@option('-s', '--secure', required=False,
help='Enable Intel SGX Enclave', is_flag=True, default=False)
@option('-c', '--cert_path',
help='The cert path where pki certs will reside', required=False)
@option('--fqdn', required=False, type=click_types.FQDN,
help=f'The fully qualified domain name of'
f' aggregator node [{getfqdn_env()}]',
default=getfqdn_env())
def start_(plan, authorized_cols, secure, cert_path, fqdn):
def start_(plan, authorized_cols, secure, cert_path):
"""Start the aggregator service."""
from pathlib import Path

Expand All @@ -61,19 +57,18 @@ def start_(plan, authorized_cols, secure, cert_path, fqdn):
logger.info('🧿 Starting the Aggregator Service.')

if cert_path:
CERT_PATH = Path(cert_path).absolute()
(CERT_PATH / 'cert').mkdir(parents=True, exist_ok=True)
CERT_DIR = CERT_PATH / 'cert'
if not Path(CERT_DIR).exists():
cert_path = Path(cert_path).absolute()
(cert_path / 'cert').mkdir(parents=True, exist_ok=True)
cert_dir_path = cert_path / 'cert'
if not Path(cert_dir_path).exists():
echo(style('Certificate Path not found.', fg='red')
+ ' Please run `fx aggregator generate-cert-request --cert_path`'
' to generate certs under this directory first.')
if fqdn is None:
fqdn = getfqdn_env()
common_name = f'{fqdn}'.lower()
plan.get_server(root_certificate=f'{CERT_DIR}/cert_chain.crt',
private_key=f'{CERT_DIR}/server/agg_{common_name}.key',
certificate=f'{CERT_DIR}/server/agg_{common_name}.crt').serve()

common_name = plan.config['network']['settings']['agg_addr'].lower()
plan.get_server(root_certificate=f'{cert_dir_path}/cert_chain.crt',
private_key=f'{cert_dir_path}/server/agg_{common_name}.key',
certificate=f'{cert_dir_path}/server/agg_{common_name}.crt').serve()
else:
plan.get_server().serve()

Expand Down Expand Up @@ -111,17 +106,20 @@ def generate_cert_request(fqdn, cert_path=None):
server_private_key, server_csr = generate_csr(common_name, server=True)

if cert_path:
CERT_PATH = Path(cert_path).absolute()
(CERT_PATH / 'cert').mkdir(parents=True, exist_ok=True)
CERT_DIR = CERT_PATH/ 'cert' # NOQA
(CERT_DIR / 'server').mkdir(parents=True, exist_ok=True)
cert_path = Path(cert_path).absolute()
(cert_path / 'cert').mkdir(parents=True, exist_ok=True)
cert_dir_path = cert_path / 'cert'
else:
cert_dir_path = CERT_DIR

(cert_dir_path / 'server').mkdir(parents=True, exist_ok=True)

echo(' Writing AGGREGATOR certificate key pair to: ' + style(
f'{CERT_DIR}/server', fg='green'))
f'{cert_dir_path}/server', fg='green'))

# Write aggregator csr and key to disk
write_crt(server_csr, CERT_DIR / 'server' / f'{file_name}.csr')
write_key(server_private_key, CERT_DIR / 'server' / f'{file_name}.key')
write_crt(server_csr, cert_dir_path / 'server' / f'{file_name}.csr')
write_key(server_private_key, cert_dir_path / 'server' / f'{file_name}.key')


# TODO: function not used
Expand Down Expand Up @@ -171,11 +169,13 @@ def certify(fqdn, silent, cert_path=None):

# Load CSR
if cert_path:
CERT_PATH = Path(cert_path).absolute()
(CERT_PATH / 'cert').mkdir(parents=True, exist_ok=True)
CERT_DIR = CERT_PATH/ 'cert' # NOQA
cert_path = Path(cert_path).absolute()
(cert_path / 'cert').mkdir(parents=True, exist_ok=True)
cert_dir_path = cert_path / 'cert'
else:
cert_dir_path = CERT_DIR

csr_path_absolute_path = Path(CERT_DIR / f'{cert_name}.csr').absolute()
csr_path_absolute_path = Path(cert_dir_path / f'{cert_name}.csr').absolute()
if not csr_path_absolute_path.exists():
echo(style('Aggregator certificate signing request not found.', fg='red')
+ ' Please run `fx aggregator generate-cert-request`'
Expand All @@ -184,7 +184,7 @@ def certify(fqdn, silent, cert_path=None):
csr, csr_hash = read_csr(csr_path_absolute_path)

# Load private signing key
private_sign_key_absolute_path = Path(CERT_DIR / signing_key_path).absolute()
private_sign_key_absolute_path = Path(cert_dir_path / signing_key_path).absolute()
if not private_sign_key_absolute_path.exists():
echo(style('Signing key not found.', fg='red')
+ ' Please run `fx workspace certify`'
Expand All @@ -193,7 +193,7 @@ def certify(fqdn, silent, cert_path=None):
signing_key = read_key(private_sign_key_absolute_path)

# Load signing cert
signing_crt_absolute_path = Path(CERT_DIR / signing_crt_path).absolute()
signing_crt_absolute_path = Path(cert_dir_path / signing_crt_path).absolute()
if not signing_crt_absolute_path.exists():
echo(style('Signing certificate not found.', fg='red')
+ ' Please run `fx workspace certify`'
Expand All @@ -206,7 +206,7 @@ def certify(fqdn, silent, cert_path=None):
+ ' = '
+ style(f'{csr_hash}', fg='red'))

crt_path_absolute_path = Path(CERT_DIR / f'{cert_name}.crt').absolute()
crt_path_absolute_path = Path(cert_dir_path / f'{cert_name}.crt').absolute()

if silent:

Expand Down
72 changes: 38 additions & 34 deletions openfl/interface/collaborator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,18 @@ def start_(plan, collaborator_name, data_config, secure, cert_path):
logger.info('🧿 Starting a Collaborator Service.')

if cert_path:
CERT_PATH = Path(cert_path).absolute()
(CERT_PATH / 'cert').mkdir(parents=True, exist_ok=True)
CERT_DIR = CERT_PATH / 'cert'
if not Path(CERT_DIR).exists():
cert_path = Path(cert_path).absolute()
(cert_path / 'cert').mkdir(parents=True, exist_ok=True)
cert_dir_path = cert_path / 'cert'
if not Path(cert_dir_path).exists():
echo(style('Certificate Path not found.', fg='red')
+ ' Please run `fx collaborator generate-cert-request --cert_path`'
' to generate certs under this directory first.')
common_name = f'{collaborator_name}'.lower()
plan.get_collaborator(collaborator_name,
root_certificate=f'{CERT_DIR}/cert_chain.crt',
private_key=f'{CERT_DIR}/client/col_{common_name}.key',
certificate=f'{CERT_DIR}/client/col_{common_name}.crt').run()
root_certificate=f'{cert_dir_path}/cert_chain.crt',
private_key=f'{cert_dir_path}/client/col_{common_name}.key',
certificate=f'{cert_dir_path}/client/col_{common_name}.crt').run()
else:
plan.get_collaborator(collaborator_name).run()

Expand Down Expand Up @@ -167,18 +167,20 @@ def generate_cert_request(collaborator_name, data_path, silent, skip_package, ce
client_private_key, client_csr = generate_csr(common_name, server=False)

if cert_path:
CERT_PATH = Path(cert_path).absolute()
(CERT_PATH / 'cert').mkdir(parents=True, exist_ok=True)
CERT_DIR = CERT_PATH/ 'cert' # NOQA
cert_path = Path(cert_path).absolute()
(cert_path / 'cert').mkdir(parents=True, exist_ok=True)
cert_dir_path = cert_path / 'cert'
else:
cert_dir_path = CERT_DIR

(CERT_DIR / 'client').mkdir(parents=True, exist_ok=True)
(cert_dir_path / 'client').mkdir(parents=True, exist_ok=True)

echo(' Moving COLLABORATOR certificate to: ' + style(
f'{CERT_DIR}', fg='green'))
f'{cert_dir_path}', fg='green'))

# Write collaborator csr and key to disk
write_crt(client_csr, CERT_DIR / 'client' / f'{file_name}.csr')
write_key(client_private_key, CERT_DIR / 'client' / f'{file_name}.key')
write_crt(client_csr, cert_dir_path / 'client' / f'{file_name}.csr')
write_key(client_private_key, cert_dir_path / 'client' / f'{file_name}.key')

if not skip_package:
from shutil import copytree
Expand All @@ -199,7 +201,7 @@ def generate_cert_request(collaborator_name, data_path, silent, skip_package, ce

ignore = ignore_patterns('__pycache__', '*.key', '*.srl', '*.pem')
# Copy the current directory into the temporary directory
copytree(f'{CERT_DIR}/client', tmp_dir, ignore=ignore)
copytree(f'{cert_dir_path}/client', tmp_dir, ignore=ignore)

for f in glob(f'{tmp_dir}/*'):
if common_name not in basename(f):
Expand Down Expand Up @@ -312,15 +314,17 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False, cert_pat
common_name = f'{collaborator_name}'.lower()

if cert_path:
CERT_PATH = Path(cert_path).absolute()
(CERT_PATH / 'cert').mkdir(parents=True, exist_ok=True)
CERT_DIR = CERT_PATH/ 'cert' # NOQA
cert_path = Path(cert_path).absolute()
(cert_path / 'cert').mkdir(parents=True, exist_ok=True)
cert_dir_path = cert_path / 'cert'
else:
cert_dir_path = CERT_DIR

if not import_:
if request_pkg:
Path(f'{CERT_DIR}/client').mkdir(parents=True, exist_ok=True)
unpack_archive(request_pkg, extract_dir=f'{CERT_DIR}/client')
csr = glob(f'{CERT_DIR}/client/*.csr')[0]
Path(f'{cert_dir_path}/client').mkdir(parents=True, exist_ok=True)
unpack_archive(request_pkg, extract_dir=f'{cert_dir_path}/client')
csr = glob(f'{cert_dir_path}/client/*.csr')[0]
else:
if collaborator_name is None:
echo('collaborator_name can only be omitted if signing\n'
Expand All @@ -329,8 +333,8 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False, cert_pat
'Example: fx collaborator certify --request-pkg '
'col_one_to_agg_cert_request.zip')
return
csr = glob(f'{CERT_DIR}/client/col_{common_name}.csr')[0]
copy(csr, CERT_DIR)
csr = glob(f'{cert_dir_path}/client/col_{common_name}.csr')[0]
copy(csr, cert_dir_path)
cert_name = splitext(csr)[0]
file_name = basename(cert_name)
signing_key_path = 'ca/signing-ca/private/signing-ca.key'
Expand All @@ -345,20 +349,20 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False, cert_pat
csr, csr_hash = read_csr(f'{cert_name}.csr')

# Load private signing key
if not Path(CERT_DIR / signing_key_path).exists():
if not Path(cert_dir_path / signing_key_path).exists():
echo(style('Signing key not found.', fg='red')
+ ' Please run `fx workspace certify`'
' to initialize the local certificate authority.')

signing_key = read_key(CERT_DIR / signing_key_path)
signing_key = read_key(cert_dir_path / signing_key_path)

# Load signing cert
if not Path(CERT_DIR / signing_crt_path).exists():
if not Path(cert_dir_path / signing_crt_path).exists():
echo(style('Signing certificate not found.', fg='red')
+ ' Please run `fx workspace certify`'
' to initialize the local certificate authority.')

signing_crt = read_crt(CERT_DIR / signing_crt_path)
signing_crt = read_crt(cert_dir_path / signing_crt_path)

echo('The CSR Hash for file '
+ style(f'{file_name}.csr', fg='green')
Expand All @@ -370,7 +374,7 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False, cert_pat
echo(' Signing COLLABORATOR certificate')
signed_col_cert = sign_certificate(csr, signing_key, signing_crt.subject)
write_crt(signed_col_cert, f'{cert_name}.crt')
register_collaborator(CERT_DIR / 'client' / f'{file_name}.crt')
register_collaborator(cert_dir_path / 'client' / f'{file_name}.crt')

else:

Expand All @@ -379,7 +383,7 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False, cert_pat
echo(' Signing COLLABORATOR certificate')
signed_col_cert = sign_certificate(csr, signing_key, signing_crt.subject)
write_crt(signed_col_cert, f'{cert_name}.crt')
register_collaborator(CERT_DIR / 'client' / f'{file_name}.crt')
register_collaborator(cert_dir_path / 'client' / f'{file_name}.crt')

else:
echo(style('Not signing certificate.', fg='red')
Expand All @@ -403,18 +407,18 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False, cert_pat

Path(f'{tmp_dir}/client').mkdir(parents=True, exist_ok=True)
# Copy the signed cert to the temporary directory
copy(f'{CERT_DIR}/client/{file_name}.crt', f'{tmp_dir}/client/')
copy(f'{cert_dir_path}/client/{file_name}.crt', f'{tmp_dir}/client/')
# Copy the CA certificate chain to the temporary directory
copy(f'{CERT_DIR}/cert_chain.crt', tmp_dir)
copy(f'{cert_dir_path}/cert_chain.crt', tmp_dir)

# Create Zip archive of directory
make_archive(archive_name, archive_type, tmp_dir)

else:
# Copy the signed certificate and cert chain into PKI_DIR
previous_crts = glob(f'{CERT_DIR}/client/*.crt')
unpack_archive(import_, extract_dir=CERT_DIR)
updated_crts = glob(f'{CERT_DIR}/client/*.crt')
previous_crts = glob(f'{cert_dir_path}/client/*.crt')
unpack_archive(import_, extract_dir=cert_dir_path)
updated_crts = glob(f'{cert_dir_path}/client/*.crt')
cert_difference = list(set(updated_crts) - set(previous_crts))
if len(cert_difference) != 0:
crt = basename(cert_difference[0])
Expand Down
Loading

0 comments on commit 871aa83

Please sign in to comment.