Skip to content

Commit

Permalink
Use model name from test function name
Browse files Browse the repository at this point in the history
Signed-off-by: noopur <[email protected]>
  • Loading branch information
noopurintel committed Nov 14, 2024
1 parent 6044e30 commit 2aa9394
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 84 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/task_runner_e2e_wo_mtls.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ env:

jobs:
test_run:
name: tr
name: tr_wo_mtls
runs-on: ubuntu-22.04
timeout-minutes: 120 # 2 hours
strategy:
Expand Down
19 changes: 14 additions & 5 deletions tests/end_to_end/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,24 @@ pip install -r test-requirements.txt
To run a specific test case, use below command:

```sh
python -m pytest tests/end_to_end/test_suites/<test_case_filename> -k <marker> -s
python -m pytest -s tests/end_to_end/test_suites/<test_case_filename> -k <test_case_name>
```

** -s will ensure all the logs are printed on screen. Ignore, if not required.

To modify the number of collaborators, rounds to train and/or model name, use below parameters:
1. --num_collaborators
2. --num_rounds
3. --model_name
Below parameters are available for modification:

1. --num_collaborators <int> - to modify the number of collaborators
2. --num_rounds <int> - to modify the number of rounds to train
3. --model_name <str> - to use a specific model
4. --disable_tls - to disable TLS communication (by default it is enabled)
5. --disable_client_auth - to disable the client authentication (by default it is enabled)

For example, to run Task runner with - torch_cnn_mnist model, 3 collaborators, 5 rounds and non-TLS scenario:

```sh
python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py --num_rounds 5 --num_collaborators 3 --model_name torch_cnn_mnist --disable_tls
```

### Output Structure

Expand Down
113 changes: 57 additions & 56 deletions tests/end_to_end/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def pytest_addoption(parser):
"--model_name",
action="store",
type=str,
default=constants.DEFAULT_MODEL_NAME,
help="Model name",
)
parser.addoption(
Expand Down Expand Up @@ -209,7 +208,7 @@ def pytest_sessionfinish(session, exitstatus):
log.debug(f"Cleared .pytest_cache directory at {cache_dir}")


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def fx_federation(request, pytestconfig):
"""
Fixture for federation. This fixture is used to create the model owner, aggregator, and collaborators.
Expand All @@ -221,14 +220,15 @@ def fx_federation(request, pytestconfig):
Returns:
federation_fixture: Named tuple containing the objects for model owner, aggregator, and collaborators
Note: As this is a module level fixture, thus no import is required at test level.
Note: As this is a function level fixture, thus no import is required at test level.
"""
collaborators = []
agg_domain_name = "localhost"

# Parse the command line arguments
args = parse_arguments()
model_name = args.model_name
# Use the model name from the test case name if not provided as a command line argument
model_name = args.model_name if args.model_name else request.node.name.split("test_")[1]
results_dir = args.results_dir or pytestconfig.getini("results_dir")
num_collaborators = args.num_collaborators
num_rounds = args.num_rounds
Expand All @@ -249,6 +249,7 @@ def fx_federation(request, pytestconfig):
raise ValueError(f"Invalid model name: {model_name}")

workspace_name = f"workspace_{model_name}"
log.info(f"Workspace name is: {workspace_name}")

# Create model owner object and the workspace for the model
model_owner = participants.ModelOwner(workspace_name, model_name)
Expand All @@ -259,55 +260,55 @@ def fx_federation(request, pytestconfig):
log.error(f"Failed to create the workspace: {e}")
raise e

# Modify the plan
try:
model_owner.modify_plan(new_rounds=num_rounds, num_collaborators=num_collaborators, disable_tls=disable_tls)
except Exception as e:
log.error(f"Failed to modify the plan: {e}")
raise e

# For TLS enabled (default) scenario: when the workspace is certified, the collaborators are registered as well
# For TLS disabled scenario: collaborators need to be registered explicitly
if args.disable_tls:
log.info("Disabling TLS for communication")
model_owner.register_collaborators(num_collaborators)
else:
log.info("Enabling TLS for communication")
try:
model_owner.certify_workspace()
except Exception as e:
log.error(f"Failed to certify the workspace: {e}")
raise e

# Initialize the plan
try:
model_owner.initialize_plan(agg_domain_name=agg_domain_name)
except Exception as e:
log.error(f"Failed to initialize the plan: {e}")
raise e

# Create the objects for aggregator and collaborators
aggregator = participants.Aggregator(
agg_domain_name=agg_domain_name, workspace_path=workspace_path
)

for i in range(num_collaborators):
collaborator = participants.Collaborator(
collaborator_name=f"collaborator{i+1}",
data_directory_path=i + 1,
workspace_path=workspace_path,
)
collaborator.create_collaborator()
collaborators.append(collaborator)

# Return the federation fixture
return federation_fixture(
model_owner=model_owner,
aggregator=aggregator,
collaborators=collaborators,
model_name=model_name,
disable_client_auth=disable_client_auth,
disable_tls=disable_tls,
workspace_path=workspace_path,
results_dir=results_dir,
)
# # Modify the plan
# try:
# model_owner.modify_plan(new_rounds=num_rounds, num_collaborators=num_collaborators, disable_tls=disable_tls)
# except Exception as e:
# log.error(f"Failed to modify the plan: {e}")
# raise e

# # For TLS enabled (default) scenario: when the workspace is certified, the collaborators are registered as well
# # For TLS disabled scenario: collaborators need to be registered explicitly
# if args.disable_tls:
# log.info("Disabling TLS for communication")
# model_owner.register_collaborators(num_collaborators)
# else:
# log.info("Enabling TLS for communication")
# try:
# model_owner.certify_workspace()
# except Exception as e:
# log.error(f"Failed to certify the workspace: {e}")
# raise e

# # Initialize the plan
# try:
# model_owner.initialize_plan(agg_domain_name=agg_domain_name)
# except Exception as e:
# log.error(f"Failed to initialize the plan: {e}")
# raise e

# # Create the objects for aggregator and collaborators
# aggregator = participants.Aggregator(
# agg_domain_name=agg_domain_name, workspace_path=workspace_path
# )

# for i in range(num_collaborators):
# collaborator = participants.Collaborator(
# collaborator_name=f"collaborator{i+1}",
# data_directory_path=i + 1,
# workspace_path=workspace_path,
# )
# collaborator.create_collaborator()
# collaborators.append(collaborator)

# # Return the federation fixture
# return federation_fixture(
# model_owner=model_owner,
# aggregator=aggregator,
# collaborators=collaborators,
# model_name=model_name,
# disable_client_auth=disable_client_auth,
# disable_tls=disable_tls,
# workspace_path=workspace_path,
# results_dir=results_dir,
# )
42 changes: 21 additions & 21 deletions tests/end_to_end/test_suites/task_runner_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,30 @@ def test_torch_cnn_mnist(fx_federation):
"""
log.info("Testing torch_cnn_mnist model")

# Setup PKI for trusted communication within the federation
if not fx_federation.disable_tls:
assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication"
# # Setup PKI for trusted communication within the federation
# if not fx_federation.disable_tls:
# assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication"

# Start the federation
results = fed_helper.run_federation(fx_federation)
# # Start the federation
# results = fed_helper.run_federation(fx_federation)

# Verify the completion of the federation run
assert fed_helper.verify_federation_run_completion(fx_federation, results), "Federation completion failed"
# # Verify the completion of the federation run
# assert fed_helper.verify_federation_run_completion(fx_federation, results), "Federation completion failed"


@pytest.mark.keras_cnn_mnist
def test_keras_cnn_mnist(fx_federation):
log.info("Testing keras_cnn_mnist model")

# Setup PKI for trusted communication within the federation
if not fx_federation.disable_tls:
assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication"
# # Setup PKI for trusted communication within the federation
# if not fx_federation.disable_tls:
# assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication"

# Start the federation
results = fed_helper.run_federation(fx_federation)
# # Start the federation
# results = fed_helper.run_federation(fx_federation)

# Verify the completion of the federation run
assert fed_helper.verify_federation_run_completion(fx_federation, results), "Federation completion failed"
# # Verify the completion of the federation run
# assert fed_helper.verify_federation_run_completion(fx_federation, results), "Federation completion failed"


@pytest.mark.torch_cnn_histology
Expand All @@ -49,12 +49,12 @@ def test_torch_cnn_histology(fx_federation):
"""
log.info("Testing torch_cnn_histology model")

# Setup PKI for trusted communication within the federation
if not fx_federation.disable_tls:
assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication"
# # Setup PKI for trusted communication within the federation
# if not fx_federation.disable_tls:
# assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication"

# Start the federation
results = fed_helper.run_federation(fx_federation)
# # Start the federation
# results = fed_helper.run_federation(fx_federation)

# Verify the completion of the federation run
assert fed_helper.verify_federation_run_completion(fx_federation, results), "Federation completion failed"
# # Verify the completion of the federation run
# assert fed_helper.verify_federation_run_completion(fx_federation, results), "Federation completion failed"
2 changes: 1 addition & 1 deletion tests/end_to_end/utils/conftest_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def parse_arguments():
parser.add_argument("--results_dir", type=str, required=False, default="results", help="Directory to store the results")
parser.add_argument("--num_collaborators", type=int, default=2, help="Number of collaborators")
parser.add_argument("--num_rounds", type=int, default=5, help="Number of rounds to train")
parser.add_argument("--model_name", type=str, default="torch_cnn_mnist", help="Model name")
parser.add_argument("--model_name", type=str, help="Model name")
parser.add_argument("--disable_client_auth", action="store_true", help="Disable client authentication")
parser.add_argument("--disable_tls", action="store_true", help="Disable TLS for communication")
args = parser.parse_known_args()[0]
Expand Down

0 comments on commit 2aa9394

Please sign in to comment.