Skip to content

Commit

Permalink
refining tests
Browse files Browse the repository at this point in the history
Signed-off-by: Mansi Sharma <[email protected]>
  • Loading branch information
Mansi Sharma committed Jan 4, 2023
1 parent e6ff4fe commit c519c91
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 22 deletions.
16 changes: 16 additions & 0 deletions .github/workflows/pki.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,20 @@ jobs:
- name: Test PKI
run: |
python tests/github/pki_wrong_cn.py
test_pki_cert_location:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- name: Set up Python 3.8
uses: actions/setup-python@v3
with:
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .
- name: Test PKI
run: |
python tests/github/test_pki_cert_location.py
100 changes: 100 additions & 0 deletions tests/github/test_pki_cert_location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (C) 2020-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import time
import socket
import argparse
from pathlib import Path
from subprocess import check_call
from concurrent.futures import ProcessPoolExecutor

from openfl.utilities.utils import rmtree
from tests.github.utils import create_collaborator, create_certified_workspace, certify_aggregator


if __name__ == '__main__':
# Test the pipeline
parser = argparse.ArgumentParser()
workspace_choice = []
with os.scandir('openfl-workspace') as iterator:
for entry in iterator:
if entry.name not in ['__init__.py', 'workspace', 'default']:
workspace_choice.append(entry.name)
parser.add_argument('--template', default='keras_cnn_mnist', choices=workspace_choice)
parser.add_argument('--fed_workspace', default='fed_work12345alpha81671')
parser.add_argument('--col1', default='one123dragons')
parser.add_argument('--col2', default='beta34unicorns')
parser.add_argument('--rounds-to-train')
parser.add_argument('--col1-data-path', default='1')
parser.add_argument('--col2-data-path', default='2')
parser.add_argument('--agg-cert-path', default=Path.cwd())
parser.add_argument('--col1-cert-path', default=Path.cwd())
parser.add_argument('--col2-cert-path', default=Path.cwd())
parser.add_argument('--save-model')

origin_dir = Path.cwd().resolve()
args = parser.parse_args()
fed_workspace = args.fed_workspace
archive_name = f'{fed_workspace}.zip'
fqdn = socket.getfqdn()
template = args.template
rounds_to_train = args.rounds_to_train
col1, col2 = args.col1, args.col2
col1_data_path, col2_data_path = args.col1_data_path, args.col2_data_path
agg_cert_path = Path(args.agg_cert_path).resolve()
col1_cert_path = Path(args.col1_cert_path).resolve()
col2_cert_path = Path(args.col2_cert_path).resolve()
save_model = args.save_model

# START
# =====
# Make sure you are in a Python virtual environment with the FL package installed.
create_certified_workspace(fed_workspace, template, fqdn, rounds_to_train, agg_cert_path)
certify_aggregator(fqdn, agg_cert_path)

workspace_root = Path().resolve() # Get the absolute directory path for the workspace

# Create collaborator #1
create_collaborator(col1, workspace_root, col1_data_path, archive_name,
fed_workspace, col1_cert_path, agg_cert_path)

# Create collaborator #2
create_collaborator(col2, workspace_root, col2_data_path, archive_name,
fed_workspace, col2_cert_path, agg_cert_path)

# Run the federation
with ProcessPoolExecutor(max_workers=3) as executor:
executor.submit(
check_call, ['fx', 'aggregator', 'start', '-c', agg_cert_path],
cwd=workspace_root)
time.sleep(5)

dir1 = workspace_root / col1 / fed_workspace
executor.submit(
check_call, ['fx', 'collaborator', 'start', '-n', col1, '-c', col1_cert_path],
cwd=dir1)

dir2 = workspace_root / col2 / fed_workspace
executor.submit(
check_call, ['fx', 'collaborator', 'start', '-n', col2, '-c', col2_cert_path],
cwd=dir2)

# Convert model to native format
if save_model:
check_call(
['fx', 'model', 'save', '-i', f'./save/{template}_last.pbuf', '-o', save_model],
cwd=workspace_root)

# Clear cert paths
check_call(
['fx', 'aggregator', 'uninstall-cert', '-c', agg_cert_path],
cwd=workspace_root)
check_call(
['fx', 'collaborator', 'uninstall-cert', '-c', col1_cert_path],
cwd=workspace_root / col1 / fed_workspace)
check_call(
['fx', 'collaborator', 'uninstall-cert', '-c', col2_cert_path],
cwd=workspace_root / col2 / fed_workspace)
os.chdir(origin_dir)
rmtree(workspace_root)
73 changes: 51 additions & 22 deletions tests/github/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import tarfile


def create_collaborator(col, workspace_root, data_path, archive_name, fed_workspace):
def create_collaborator(col, workspace_root, data_path, archive_name, fed_workspace,
cert_path=None, ca_cert_path=None):
# Copy workspace to collaborator directories (these can be on different machines)
col_path = workspace_root / col
shutil.rmtree(col_path, ignore_errors=True) # Remove any existing directory
Expand All @@ -22,27 +23,47 @@ def create_collaborator(col, workspace_root, data_path, archive_name, fed_worksp

# Create collaborator certificate request
# Remove '--silent' if you run this manually
check_call(
['fx', 'collaborator', 'generate-cert-request', '-d', data_path, '-n', col, '--silent'],
cwd=col_path / fed_workspace
)
if cert_path:
check_call(
['fx', 'collaborator', 'generate-cert-request', '-d', data_path,
'-n', col, '-c', cert_path, '--silent'],
cwd=col_path / fed_workspace
)
else:
check_call(
['fx', 'collaborator', 'generate-cert-request', '-d', data_path,
'-n', col, '--silent'],
cwd=col_path / fed_workspace
)

# Sign collaborator certificate
# Remove '--silent' if you run this manually
request_pkg = col_path / fed_workspace / f'col_{col}_to_agg_cert_request.zip'
check_call(
['fx', 'collaborator', 'certify', '--request-pkg', str(request_pkg), '--silent'],
cwd=workspace_root)
if ca_cert_path:
check_call(
['fx', 'collaborator', 'certify', '--request-pkg', str(request_pkg),
'-c', ca_cert_path, '--silent'],
cwd=workspace_root)
else:
check_call(
['fx', 'collaborator', 'certify', '--request-pkg', str(request_pkg), '--silent'],
cwd=workspace_root)

# Import the signed certificate from the aggregator
import_path = workspace_root / f'agg_to_col_{col}_signed_cert.zip'
check_call(
['fx', 'collaborator', 'certify', '--import', import_path],
cwd=col_path / fed_workspace
)


def create_certified_workspace(path, template, fqdn, rounds_to_train):
if cert_path:
check_call(
['fx', 'collaborator', 'certify', '--import', import_path, '-c', cert_path],
cwd=col_path / fed_workspace
)
else:
check_call(
['fx', 'collaborator', 'certify', '--import', import_path],
cwd=col_path / fed_workspace
)


def create_certified_workspace(path, template, fqdn, rounds_to_train, cert_path=None):
shutil.rmtree(path, ignore_errors=True)
check_call(['fx', 'workspace', 'create', '--prefix', path, '--template', template])
os.chdir(path)
Expand All @@ -62,18 +83,26 @@ def create_certified_workspace(path, template, fqdn, rounds_to_train):
except (ValueError, TypeError):
pass
# Create certificate authority for workspace
check_call(['fx', 'workspace', 'certify'])
if cert_path:
check_call(['fx', 'workspace', 'certify', '-c', cert_path])
else:
check_call(['fx', 'workspace', 'certify'])

# Export FL workspace
check_call(['fx', 'workspace', 'export'])


def certify_aggregator(fqdn):
# Create aggregator certificate
check_call(['fx', 'aggregator', 'generate-cert-request', '--fqdn', fqdn])

# Sign aggregator certificate
check_call(['fx', 'aggregator', 'certify', '--fqdn', fqdn, '--silent'])
def certify_aggregator(fqdn, cert_path=None):
if cert_path:
# Create aggregator certificate
check_call(['fx', 'aggregator', 'generate-cert-request', '--fqdn', fqdn, '-c', cert_path])
# Sign aggregator certificate
check_call(['fx', 'aggregator', 'certify', '--fqdn', fqdn, '-c', cert_path, '--silent'])
else:
# Create aggregator certificate
check_call(['fx', 'aggregator', 'generate-cert-request', '--fqdn', fqdn])
# Sign aggregator certificate
check_call(['fx', 'aggregator', 'certify', '--fqdn', fqdn, '--silent'])


def create_signed_cert_for_collaborator(col, data_path):
Expand Down

0 comments on commit c519c91

Please sign in to comment.