diff --git a/tests/end_to_end/conftest.py b/tests/end_to_end/conftest.py index efe9febbc6..193609034e 100644 --- a/tests/end_to_end/conftest.py +++ b/tests/end_to_end/conftest.py @@ -249,66 +249,68 @@ def fx_federation(request, pytestconfig): raise ValueError(f"Invalid model name: {model_name}") workspace_name = f"workspace_{model_name}" - log.info(f"Workspace name is: {workspace_name}") # Create model owner object and the workspace for the model model_owner = participants.ModelOwner(workspace_name, model_name) - try: workspace_path = model_owner.create_workspace(results_dir=results_dir) except Exception as e: log.error(f"Failed to create the workspace: {e}") raise e - # # Modify the plan - # try: - # model_owner.modify_plan(new_rounds=num_rounds, num_collaborators=num_collaborators, disable_tls=disable_tls) - # except Exception as e: - # log.error(f"Failed to modify the plan: {e}") - # raise e - - # # For TLS enabled (default) scenario: when the workspace is certified, the collaborators are registered as well - # # For TLS disabled scenario: collaborators need to be registered explicitly - # if args.disable_tls: - # log.info("Disabling TLS for communication") - # model_owner.register_collaborators(num_collaborators) - # else: - # log.info("Enabling TLS for communication") - # try: - # model_owner.certify_workspace() - # except Exception as e: - # log.error(f"Failed to certify the workspace: {e}") - # raise e - - # # Initialize the plan - # try: - # model_owner.initialize_plan(agg_domain_name=agg_domain_name) - # except Exception as e: - # log.error(f"Failed to initialize the plan: {e}") - # raise e - - # # Create the objects for aggregator and collaborators - # aggregator = participants.Aggregator( - # agg_domain_name=agg_domain_name, workspace_path=workspace_path - # ) - - # for i in range(num_collaborators): - # collaborator = participants.Collaborator( - # collaborator_name=f"collaborator{i+1}", - # data_directory_path=i + 1, - # workspace_path=workspace_path, - # ) - # collaborator.create_collaborator() - # collaborators.append(collaborator) - - # # Return the federation fixture - # return federation_fixture( - # model_owner=model_owner, - # aggregator=aggregator, - # collaborators=collaborators, - # model_name=model_name, - # disable_client_auth=disable_client_auth, - # disable_tls=disable_tls, - # workspace_path=workspace_path, - # results_dir=results_dir, - # ) + # Modify the plan + try: + model_owner.modify_plan(new_rounds=num_rounds, num_collaborators=num_collaborators, disable_tls=disable_tls) + except Exception as e: + log.error(f"Failed to modify the plan: {e}") + raise e + + # For TLS enabled (default) scenario: when the workspace is certified, the collaborators are registered as well + # For TLS disabled scenario: collaborators need to be registered explicitly + if args.disable_tls: + log.info("Disabling TLS for communication") + try: + model_owner.register_collaborators(num_collaborators) + except Exception as e: + log.error(f"Failed to register the collaborators: {e}") + raise e + else: + log.info("Enabling TLS for communication") + try: + model_owner.certify_workspace() + except Exception as e: + log.error(f"Failed to certify the workspace: {e}") + raise e + + # Initialize the plan + try: + model_owner.initialize_plan(agg_domain_name=agg_domain_name) + except Exception as e: + log.error(f"Failed to initialize the plan: {e}") + raise e + + # Create the objects for aggregator and collaborators + aggregator = participants.Aggregator( + agg_domain_name=agg_domain_name, workspace_path=workspace_path + ) + + for i in range(num_collaborators): + collaborator = participants.Collaborator( + collaborator_name=f"collaborator{i+1}", + data_directory_path=i + 1, + workspace_path=workspace_path, + ) + collaborator.create_collaborator() + collaborators.append(collaborator) + + # Return the federation fixture + return federation_fixture( + model_owner=model_owner, + aggregator=aggregator, + collaborators=collaborators, + model_name=model_name, + disable_client_auth=disable_client_auth, + disable_tls=disable_tls, + workspace_path=workspace_path, + results_dir=results_dir, + ) diff --git a/tests/end_to_end/models/participants.py b/tests/end_to_end/models/participants.py index 7c49a2a7ac..0ef22d8f28 100644 --- a/tests/end_to_end/models/participants.py +++ b/tests/end_to_end/models/participants.py @@ -190,27 +190,31 @@ def register_collaborators(self, num_collaborators=None): bool: True if successful, else False """ self.cols_path = os.path.join(self.workspace_path, "plan", "cols.yaml") - log.info(f"Registering the collaborators in {self.cols_path}") - # Open the file and modify the entries + log.info(f"Registering the collaborators..") self.num_collaborators = num_collaborators if num_collaborators else self.num_collaborators - # Straightforward writing to the yaml file is not recommended here - # As the file might contain spaces and tabs which can cause issues - with open(self.cols_path, "r", encoding="utf-8") as f: - doc = yaml.load(f, Loader=yaml.FullLoader) + try: + # Straightforward writing to the yaml file is not recommended here + # As the file might contain spaces and tabs which can cause issues + with open(self.cols_path, "r", encoding="utf-8") as f: + doc = yaml.load(f, Loader=yaml.FullLoader) - if "collaborators" not in doc.keys() or not doc["collaborators"]: - doc["collaborators"] = [] # Create empty list + if "collaborators" not in doc.keys() or not doc["collaborators"]: + doc["collaborators"] = [] # Create empty list - for i in range(num_collaborators): - col_name = "collaborator" + str(i+1) - doc["collaborators"].append(col_name) - with open(self.cols_path, "w", encoding="utf-8") as f: - yaml.dump(doc, f) + for i in range(num_collaborators): + col_name = "collaborator" + str(i+1) + doc["collaborators"].append(col_name) + with open(self.cols_path, "w", encoding="utf-8") as f: + yaml.dump(doc, f) - log.info( - f"Modified the plan to train the model for collaborators {self.num_collaborators} and {self.rounds_to_train} rounds" - ) + log.info( + f"Successfully registered collaborators in {self.cols_path}" + ) + except Exception as e: + log.error(f"Failed to register the collaborators: {e}") + raise e + return True def certify_aggregator(self, agg_domain_name): """ diff --git a/tests/end_to_end/test_suites/sample_tests.py b/tests/end_to_end/test_suites/sample_tests.py index 7c528277e8..a27bf76cbf 100644 --- a/tests/end_to_end/test_suites/sample_tests.py +++ b/tests/end_to_end/test_suites/sample_tests.py @@ -19,8 +19,8 @@ # 7. Start the federation using aggregator and given no of collaborators. # 8. Verify the completion of the federation run. -@pytest.mark.sample_model -def test_sample_model(fx_federation): +@pytest.mark.sample_model_name +def test_sample_model_name(fx_federation): """ Add a proper docstring here. """ diff --git a/tests/end_to_end/test_suites/task_runner_tests.py b/tests/end_to_end/test_suites/task_runner_tests.py index 383c09231f..371fee8f08 100644 --- a/tests/end_to_end/test_suites/task_runner_tests.py +++ b/tests/end_to_end/test_suites/task_runner_tests.py @@ -16,30 +16,30 @@ def test_torch_cnn_mnist(fx_federation): """ log.info("Testing torch_cnn_mnist model") - # # Setup PKI for trusted communication within the federation - # if not fx_federation.disable_tls: - # assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication" + # Setup PKI for trusted communication within the federation + if not fx_federation.disable_tls: + assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication" - # # Start the federation - # results = fed_helper.run_federation(fx_federation) + # Start the federation + results = fed_helper.run_federation(fx_federation) - # # Verify the completion of the federation run - # assert fed_helper.verify_federation_run_completion(fx_federation, results), "Federation completion failed" + # Verify the completion of the federation run + assert fed_helper.verify_federation_run_completion(fx_federation, results), "Federation completion failed" @pytest.mark.keras_cnn_mnist def test_keras_cnn_mnist(fx_federation): log.info("Testing keras_cnn_mnist model") - # # Setup PKI for trusted communication within the federation - # if not fx_federation.disable_tls: - # assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication" + # Setup PKI for trusted communication within the federation + if not fx_federation.disable_tls: + assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication" - # # Start the federation - # results = fed_helper.run_federation(fx_federation) + # Start the federation + results = fed_helper.run_federation(fx_federation) - # # Verify the completion of the federation run - # assert fed_helper.verify_federation_run_completion(fx_federation, results), "Federation completion failed" + # Verify the completion of the federation run + assert fed_helper.verify_federation_run_completion(fx_federation, results), "Federation completion failed" @pytest.mark.torch_cnn_histology @@ -49,12 +49,12 @@ def test_torch_cnn_histology(fx_federation): """ log.info("Testing torch_cnn_histology model") - # # Setup PKI for trusted communication within the federation - # if not fx_federation.disable_tls: - # assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication" + # Setup PKI for trusted communication within the federation + if not fx_federation.disable_tls: + assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication" - # # Start the federation - # results = fed_helper.run_federation(fx_federation) + # Start the federation + results = fed_helper.run_federation(fx_federation) - # # Verify the completion of the federation run - # assert fed_helper.verify_federation_run_completion(fx_federation, results), "Federation completion failed" + # Verify the completion of the federation run + assert fed_helper.verify_federation_run_completion(fx_federation, results), "Federation completion failed"