diff --git a/tests/end_to_end/utils/common_fixtures.py b/tests/end_to_end/utils/common_fixtures.py index 0c96eafc91..a50446b234 100644 --- a/tests/end_to_end/utils/common_fixtures.py +++ b/tests/end_to_end/utils/common_fixtures.py @@ -8,12 +8,6 @@ import numpy as np import tests.end_to_end.utils.constants as constants -from tests.end_to_end.utils.wf_helper import ( - init_collaborator_private_attr_index, - init_collaborator_private_attr_name, - init_collaborate_pvt_attr_np, - init_agg_pvt_attr_np -) import tests.end_to_end.utils.federation_helper as fh import tests.end_to_end.utils.ssh_helper as ssh from tests.end_to_end.models import aggregator as agg_model, model_owner as mo_model @@ -259,10 +253,14 @@ def fx_local_federated_workflow(request): collaborators, and backend. """ # Import is done inline because Task Runner does not support importing below openfl packages - from openfl.experimental.workflow.interface import Aggregator, Collaborator from openfl.experimental.workflow.runtime import LocalRuntime - + from tests.end_to_end.utils.wf_helper import ( + init_collaborator_private_attr_index, + init_collaborator_private_attr_name, + init_collaborate_pvt_attr_np, + init_agg_pvt_attr_np + ) collab_callback_func = request.param[0] if hasattr(request, 'param') and request.param else None collab_value = request.param[1] if hasattr(request, 'param') and request.param else None agg_callback_func = request.param[2] if hasattr(request, 'param') and request.param else None @@ -318,10 +316,14 @@ def fx_local_federated_workflow_prvt_attr(request): collaborators, and backend. """ # Import is done inline because Task Runner does not support importing below openfl packages - from openfl.experimental.workflow.interface import Aggregator, Collaborator from openfl.experimental.workflow.runtime import LocalRuntime - + from tests.end_to_end.utils.wf_helper import ( + init_collaborator_private_attr_index, + init_collaborator_private_attr_name, + init_collaborate_pvt_attr_np, + init_agg_pvt_attr_np + ) collab_callback_func = request.param[0] if hasattr(request, 'param') and request.param else None collab_value = request.param[1] if hasattr(request, 'param') and request.param else None agg_callback_func = request.param[2] if hasattr(request, 'param') and request.param else None diff --git a/tests/end_to_end/utils/wf_helper.py b/tests/end_to_end/utils/wf_helper.py index 019d906ff5..fcde1118d0 100644 --- a/tests/end_to_end/utils/wf_helper.py +++ b/tests/end_to_end/utils/wf_helper.py @@ -1,9 +1,13 @@ +# Copyright 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from metaflow import Flow import logging import numpy as np log = logging.getLogger(__name__) + def validate_flow(flow_obj, expected_flow_steps): """ Validate: