Skip to content

Commit

Permalink
Import within the local federated fixtures
Browse files Browse the repository at this point in the history
Signed-off-by: noopur <[email protected]>
  • Loading branch information
noopurintel committed Dec 18, 2024
1 parent c462f8e commit c581320
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
22 changes: 12 additions & 10 deletions tests/end_to_end/utils/common_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@
import numpy as np

import tests.end_to_end.utils.constants as constants
from tests.end_to_end.utils.wf_helper import (
init_collaborator_private_attr_index,
init_collaborator_private_attr_name,
init_collaborate_pvt_attr_np,
init_agg_pvt_attr_np
)
import tests.end_to_end.utils.federation_helper as fh
import tests.end_to_end.utils.ssh_helper as ssh
from tests.end_to_end.models import aggregator as agg_model, model_owner as mo_model
Expand Down Expand Up @@ -259,10 +253,14 @@ def fx_local_federated_workflow(request):
collaborators, and backend.
"""
# Import is done inline because Task Runner does not support importing below openfl packages

from openfl.experimental.workflow.interface import Aggregator, Collaborator
from openfl.experimental.workflow.runtime import LocalRuntime

from tests.end_to_end.utils.wf_helper import (
init_collaborator_private_attr_index,
init_collaborator_private_attr_name,
init_collaborate_pvt_attr_np,
init_agg_pvt_attr_np
)
collab_callback_func = request.param[0] if hasattr(request, 'param') and request.param else None
collab_value = request.param[1] if hasattr(request, 'param') and request.param else None
agg_callback_func = request.param[2] if hasattr(request, 'param') and request.param else None
Expand Down Expand Up @@ -318,10 +316,14 @@ def fx_local_federated_workflow_prvt_attr(request):
collaborators, and backend.
"""
# Import is done inline because Task Runner does not support importing below openfl packages

from openfl.experimental.workflow.interface import Aggregator, Collaborator
from openfl.experimental.workflow.runtime import LocalRuntime

from tests.end_to_end.utils.wf_helper import (
init_collaborator_private_attr_index,
init_collaborator_private_attr_name,
init_collaborate_pvt_attr_np,
init_agg_pvt_attr_np
)
collab_callback_func = request.param[0] if hasattr(request, 'param') and request.param else None
collab_value = request.param[1] if hasattr(request, 'param') and request.param else None
agg_callback_func = request.param[2] if hasattr(request, 'param') and request.param else None
Expand Down
4 changes: 4 additions & 0 deletions tests/end_to_end/utils/wf_helper.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# Copyright 2020-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from metaflow import Flow
import logging
import numpy as np

log = logging.getLogger(__name__)


def validate_flow(flow_obj, expected_flow_steps):
"""
Validate:
Expand Down

0 comments on commit c581320

Please sign in to comment.