diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 65fd431315..eaf0656678 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,101 +1,78 @@ # Contributing to OpenFL -We welcome contributions from the community. We believe that anyone can bring something valuable to OpenFL and help us to improve the project. This document explains how to contribute to OpenFL. +We welcome contributions from the community. There are several ways to contribute: +* Improvements in [documentation](https://openfl.readthedocs.io/en/latest/install.html). +* Contributing to OpenFL's codebase via bug fixes or feature additions. +* Answering questions on our [discussions page](https://github.com/securefederatedai/openfl/discussions). +* Participating in our [roadmap](https://github.com/securefederatedai/openfl/blob/develop/ROADMAP.md) discussions. -We accept various contributions from documentation improvement and bug fixing to major features proposals and [roadmap](https://github.com/intel/openfl/blob/develop/ROADMAP.md) suggestions. +We have a Slack [channel](https://join.slack.com/t/openfl/shared_invite/zt-ovzbohvn-T5fApk05~YS_iZhjJ5yaTw) and we host regular [community meetings](https://github.com/intel/openfl#support). -Documentation improvement: review our [documentation](https://openfl.readthedocs.io/en/latest/install.html) and let us know if something is not clear or not relevant. -Propose your own formulations or even write new section explaining something that you know how works, but do not see in the documentation. -Propose it through GitHub [issues](https://github.com/intel/openfl/issues/new/choose) or [Discussions](https://github.com/intel/openfl/discussions). -To propose bugs, new features, or other code improvements: +## How to contribute code +### Step 1. Open an issue -1. Check open and closed [issues](https://github.com/intel/openfl/issues) and make sure there is no similar proposal. -2. Open a [new issue](https://github.com/intel/openfl/issues/new/choose), select a relevant category (Bug report / Feature request / Report a security vulnerability) and describe your idea using the template. -3. If you want to fix a bug or create this feature by yourself, prepare a contribution. - - Format your code following the [flake8 style](https://flake8.pycqa.org/en/latest/). - - Make sure that your code is original and corresponds to [OpenFL license](#license). - - Sing your work - [see below](#sign-your-work). - - Create a [pull request](#formatting-of-pull-requests) and wait for feedback. - - Verify that all tests in our [CI/CD pipeline](#Continuous-Integration-and-Continuous-Development) passed. -4. Hurrah! You are a new contributor to OpenFL! You will see your name in released notes of the subsequent releases!😊 +Before you start making any changes, it is always good to open an [issue](https://github.com/securefederatedai/openfl/issues/new/choose) first (assuming one does not already exist), outlining your proposed changes. We can give you feedback and potentially validate the proposed changes. -Join our [Slack](https://join.slack.com/t/openfl/shared_invite/zt-ovzbohvn-T5fApk05~YS_iZhjJ5yaTw) and [Community meetings](https://github.com/intel/openfl#support) and participate in the discussions. +For minor changes (such as a documentation or bug fix), proceed to opening a Pull Request (PR) directly. -Are you an expert in Federated Learning and want to contribute to our roadmap? You can nominate yourself as a member of our Technical Steering Committee and be part of the OpenFL decision making group. 
Please reach us through our [Slack](https://join.slack.com/t/openfl/shared_invite/zt-ovzbohvn-T5fApk05~YS_iZhjJ5yaTw). +### Step 2. Make code changes -### Code format and style +To modify code, you need to fork the repository. Set up a development environment as covered in the section "Setup environment" below. -We use [flake8](https://flake8.pycqa.org/en/latest/) for PEP8 style guide enforcement. This is run as a part of our CI/CD pipeline and it’s required prior a merge. +### Step 3. Create a Pull Request (PR) -### Formatting of Pull Requests +Once the change is ready, open a PR from the branch in your fork to the `develop` branch in [securefederatedai/openfl](https://github.com/securefederatedai/openfl). OpenFL follows standard recommendations of PR formatting. Find more details [here](https://github.blog/2015-01-21-how-to-write-the-perfect-pull-request/). -OpenFL follows standard recommendations of PR formatting. Please find more details [here](https://github.blog/2015-01-21-how-to-write-the-perfect-pull-request/). +### Step 4. Sign your work -### Continuous Integration and Continuous Development +Sign off your patch commits using your real name. We discourage anonymous contributions. -OpenFL uses GitHub actions to perform all functional and unit tests. Before your contribution can be merged make sure that all your tests are passing. -For more information of what fails you can click on the “details” link near the pipeline that failed. - -![CI/CD](docs/images/CI_details.png) - -### Writing the tests + Signed-off-by: Joe Smith -The OpenFL team recommend including tests for all new features contributions. Test can be found in the “Tests” directory. -The [Tests/OpenFL folder](https://github.com/intel/openfl/tree/develop/tests/openfl) contains unit tests and the [Tests/GitHub folder](https://github.com/intel/openfl/tree/develop/tests/github) contains end-to-end and functional tests. +If you set your `user.name` and `user.email` git configs, you can sign your +commits using `git commit --signoff`. -### License +Your signature [certifies](http://developercertificate.org/) that you wrote the patch, or that you otherwise have the right to pass it on as an open-source patch. -OpenFL is licensed under the terms in [Apache 2.0 license](https://github.com/intel/openfl/blob/develop/LICENSE). By contributing to the project, you agree to the license and copyright terms therein and release your contribution under these terms. +OpenFL is licensed under the [Apache 2.0 license](https://github.com/intel/openfl/blob/develop/LICENSE). By contributing to the project, you agree to the license and copyright terms therein and release your contribution under these terms. -### Sign your work +### Step 5. Code review and merge -Please use the sign-off line at the end of the patch. Your signature certifies that you wrote the patch or otherwise have the right to pass it on as an open-source patch. The rules are pretty simple: if you can certify -the below (from [developercertificate.org](http://developercertificate.org/)): +Verify that your contribution passes all tests in our CI/CD pipeline. In case of a failure, as shown below, look into the error messages and try to fix them. -``` -Developer Certificate of Origin -Version 1.1 +![CI/CD](docs/images/CI_details.png) -Copyright (C) 2004, 2006 The Linux Foundation and its contributors. -660 York Street, Suite 102, -San Francisco, CA 94110 USA
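To catch pipeline failures before pushing, you can also run the test suite locally. A minimal sketch, assuming a `pytest`-based setup and the `tests/openfl` unit-test layout referenced above:

```shell
# Hypothetical local test run; install pytest and execute the unit tests
# from the repository root before pushing your branch.
pip install pytest
pytest tests/openfl
```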
+Meanwhile, a reviewer will review the pull request and provide comments. After a few iterations of +reviews and changes (depending on the complexity of the changes), the PR will be approved for merge. -Everyone is permitted to copy and distribute verbatim copies of this -license document, but changing it is not allowed. +## Setup environment -Developer's Certificate of Origin 1.1 +We recommend setting up a local dev environment. Clone your forked repo to your local machine and install the dependencies. -By making a contribution to this project, I certify that: +```shell +git clone https://github.com/YOUR_GITHUB_USERNAME/openfl.git +cd openfl +pip install -U pip setuptools wheel +pip install . +``` -(a) The contribution was created in whole or in part by me and I - have the right to submit it under the open source license - indicated in the file; or +## Code style -(b) The contribution is based upon previous work that, to the best - of my knowledge, is covered under an appropriate open source - license and I have the right under that license to submit that - work with modifications, whether created in whole or in part - by me, under the same open source license (unless I am - permitted to submit under a different license), as indicated - in the file; or +OpenFL uses [black](https://black.readthedocs.io/en/stable/) and [isort](https://pycqa.github.io/isort/) to format the code. -(c) The contribution was provided directly to me by some other - person who certified (a), (b) or (c) and I have not modified - it. +Run the following command at the **root** directory of the repo to format your code. -(d) I understand and agree that this project and the contribution - are public and that a record of the contribution (including all - personal information I submit with it, including my sign-off) is - maintained indefinitely and may be redistributed consistent with - this project or the open source license(s) involved. ``` +sh shell/format.sh +``` +You may need to manually fix errors that autoformatting could not resolve. To only show lint errors, run `sh shell/lint.sh` at the **root** directory of the repo. -Then you just add a line to every git commit message: - - Signed-off-by: Joe Smith - -Use your real name (sorry, no pseudonyms or anonymous contributions.) +### Docstrings +Docstrings are not checked or standardized automatically, so if you write or edit any docstrings, make sure to review them manually. OpenFL docstrings should follow the conventions below (see the sketch after this list): -If you set your `user.name` and `user.email` git configs, you can sign your -commit automatically with `git commit -s`. +A **class** or a **function** docstring may contain: +* A one-line description of the class/function. +* Paragraph(s) of detailed information. +* Optional `Examples` section. +* `Args` section for arguments under `__init__()`. 
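As an illustration of these conventions, here is a hypothetical docstring sketch (the class name and arguments below are invented for the example, not part of OpenFL's API):

```python
class ExampleAggregator:
    """Aggregate model updates from connected collaborators.

    Collects task results from each authorized collaborator and combines
    them into a single global update once per round.

    Examples:
        >>> agg = ExampleAggregator(rounds_to_train=10)

    Args:
        rounds_to_train (int): Number of federation rounds to run.
    """

    def __init__(self, rounds_to_train: int = 1) -> None:
        # Args documented in the class docstring, per the convention above.
        self.rounds_to_train = rounds_to_train
```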
\ No newline at end of file diff --git a/openfl/experimental/component/__init__.py b/openfl/experimental/component/__init__.py index 6b815db0c7..8bb0c3871a 100644 --- a/openfl/experimental/component/__init__.py +++ b/openfl/experimental/component/__init__.py @@ -1,9 +1,7 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.experimental.component package.""" -from .aggregator import Aggregator -from .collaborator import Collaborator - -__all__ = ["Aggregator", "Collaborator"] +# FIXME: Too much recursion +from openfl.experimental.component.aggregator import Aggregator +from openfl.experimental.component.collaborator import Collaborator diff --git a/openfl/experimental/component/aggregator/__init__.py b/openfl/experimental/component/aggregator/__init__.py index 34e42f18f2..6686ce37b8 100644 --- a/openfl/experimental/component/aggregator/__init__.py +++ b/openfl/experimental/component/aggregator/__init__.py @@ -1,8 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.experimental.component.aggregator package.""" -from .aggregator import Aggregator - -__all__ = ["Aggregator",] +# FIXME: Too much recursion. +from openfl.experimental.component.aggregator.aggregator import Aggregator diff --git a/openfl/experimental/component/aggregator/aggregator.py b/openfl/experimental/component/aggregator/aggregator.py index 977753a26b..116748903b 100644 --- a/openfl/experimental/component/aggregator/aggregator.py +++ b/openfl/experimental/component/aggregator/aggregator.py @@ -1,19 +1,16 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Experimental Aggregator module.""" -import time -import queue -import pickle import inspect -from threading import Event +import pickle +import queue +import time from logging import getLogger -from typing import Any, Callable -from typing import Dict, List, Tuple +from threading import Event +from typing import Any, Callable, Dict, List, Tuple -from openfl.experimental.utilities import aggregator_to_collaborator from openfl.experimental.runtime import FederatedRuntime -from openfl.experimental.utilities import checkpoint +from openfl.experimental.utilities import aggregator_to_collaborator, checkpoint from openfl.experimental.utilities.metaflow_utils import MetaflowInterface @@ -37,21 +34,19 @@ class Aggregator: """ def __init__( - self, - aggregator_uuid: str, - federation_uuid: str, - authorized_cols: List, - - flow: Any, - rounds_to_train: int = 1, - checkpoint: bool = False, - private_attributes_callable: Callable = None, - private_attributes_kwargs: Dict = {}, - - single_col_cert_common_name: str = None, - - log_metric_callback: Callable = None, - **kwargs) -> None: + self, + aggregator_uuid: str, + federation_uuid: str, + authorized_cols: List, + flow: Any, + rounds_to_train: int = 1, + checkpoint: bool = False, + private_attributes_callable: Callable = None, + private_attributes_kwargs: Dict = {}, + single_col_cert_common_name: str = None, + log_metric_callback: Callable = None, + **kwargs, + ) -> None: self.logger = getLogger(__name__) @@ -81,8 +76,9 @@ def __init__( # Event to inform aggregator that collaborators have sent the results self.collaborator_task_results = Event() # A queue for each task - self.__collaborator_tasks_queue = {collab: queue.Queue() for collab - in self.authorized_cols} + self.__collaborator_tasks_queue = { + collab: queue.Queue() for collab in self.authorized_cols + } self.flow = flow self.checkpoint = checkpoint @@ 
-111,9 +107,7 @@ def __initialize_private_attributes(self, kwargs: Dict) -> None: Call private_attrs_callable function set attributes to self.__private_attrs. """ - self.__private_attrs = self.__private_attrs_callable( - **kwargs - ) + self.__private_attrs = self.__private_attrs_callable(**kwargs) def __set_attributes_to_clone(self, clone: Any) -> None: """ @@ -123,7 +117,9 @@ def __set_attributes_to_clone(self, clone: Any) -> None: for name, attr in self.__private_attrs.items(): setattr(clone, name, attr) - def __delete_agg_attrs_from_clone(self, clone: Any, replace_str: str = None) -> None: + def __delete_agg_attrs_from_clone( + self, clone: Any, replace_str: str = None + ) -> None: """ Remove aggregator private attributes from FLSpec clone before transition from Aggregator step to collaborator steps. @@ -133,7 +129,9 @@ def __delete_agg_attrs_from_clone(self, clone: Any, replace_str: str = None) -> if len(self.__private_attrs) > 0: for attr_name in self.__private_attrs: if hasattr(clone, attr_name): - self.__private_attrs.update({attr_name: getattr(clone, attr_name)}) + self.__private_attrs.update( + {attr_name: getattr(clone, attr_name)} + ) if replace_str: setattr(clone, attr_name, replace_str) else: @@ -186,27 +184,37 @@ def run_flow(self) -> None: len_connected_collabs = len(self.connected_collaborators) if len_connected_collabs < len_sel_collabs: # Waiting for collaborators to connect. - self.logger.info("Waiting for " - + f"{len_connected_collabs}/{len_sel_collabs}" - + " collaborators to connect...") + self.logger.info( + "Waiting for " + + f"{len_connected_collabs}/{len_sel_collabs}" + + " collaborators to connect..." + ) elif self.tasks_sent_to_collaborators != len_sel_collabs: - self.logger.info("Waiting for " - + f"{self.tasks_sent_to_collaborators}/{len_sel_collabs}" - + " to make requests for tasks...") + self.logger.info( + "Waiting for " + + f"{self.tasks_sent_to_collaborators}/{len_sel_collabs}" + + " to make requests for tasks..." + ) else: # Waiting for selected collaborators to send the results. - self.logger.info("Waiting for " - + f"{self.collaborators_counter}/{len_sel_collabs}" - + " collaborators to send results...") + self.logger.info( + "Waiting for " + + f"{self.collaborators_counter}/{len_sel_collabs}" + + " collaborators to send results..." + ) time.sleep(Aggregator._get_sleep_time()) self.collaborator_task_results.clear() f_name = self.next_step if hasattr(self, "instance_snapshot"): - self.flow.restore_instance_snapshot(self.flow, list(self.instance_snapshot)) + self.flow.restore_instance_snapshot( + self.flow, list(self.instance_snapshot) + ) delattr(self, "instance_snapshot") - def call_checkpoint(self, ctx: Any, f: Callable, stream_buffer: bytes = None) -> None: + def call_checkpoint( + self, ctx: Any, f: Callable, stream_buffer: bytes = None + ) -> None: """ Perform checkpoint task. 
@@ -222,9 +230,7 @@ def call_checkpoint(self, ctx: Any, f: Callable, stream_buffer: bytes = None) -> None """ if self.checkpoint: - from openfl.experimental.interface import ( - FLSpec, - ) + from openfl.experimental.interface import FLSpec # Check if arguments are pickled, if yes then unpickle if not isinstance(ctx, FLSpec): @@ -235,7 +241,9 @@ def call_checkpoint(self, ctx: Any, f: Callable, stream_buffer: bytes = None) -> f = pickle.loads(f) if isinstance(stream_buffer, bytes): # Set stream buffer as function parameter - setattr(f.__func__, "_stream_buffer", pickle.loads(stream_buffer)) + setattr( + f.__func__, "_stream_buffer", pickle.loads(stream_buffer) + ) checkpoint(ctx, f) @@ -266,22 +274,38 @@ def get_tasks(self, collaborator_name: str) -> Tuple: # If it is time to then inform the collaborator if self.time_to_quit: self.logger.info( - f"Sending signal to collaborator {collaborator_name} to shutdown...") + f"Sending signal to collaborator {collaborator_name} to shutdown..." + ) # FIXME: 0, and "" instead of None is just for protobuf compatibility. # Cleaner solution? - return 0, "", None, Aggregator._get_sleep_time(), self.time_to_quit + return ( + 0, + "", + None, + Aggregator._get_sleep_time(), + self.time_to_quit, + ) # If not time to quit then sleep for 10 seconds time.sleep(Aggregator._get_sleep_time()) # Get collaborator step, and clone for requesting collaborator next_step, clone = self.__collaborator_tasks_queue[ - collaborator_name].get() + collaborator_name + ].get() self.tasks_sent_to_collaborators += 1 - self.logger.info("Sending tasks to collaborator" - + f" {collaborator_name} for round {self.current_round}...") - return self.current_round, next_step, pickle.dumps(clone), 0, self.time_to_quit + self.logger.info( + "Sending tasks to collaborator" + + f" {collaborator_name} for round {self.current_round}..." + ) + return ( + self.current_round, + next_step, + pickle.dumps(clone), + 0, + self.time_to_quit, + ) def do_task(self, f_name: str) -> Any: """ @@ -306,7 +330,9 @@ def do_task(self, f_name: str) -> Any: if f.__name__ == "end": f() # Take the checkpoint of "end" step - self.__delete_agg_attrs_from_clone(self.flow, "Private attributes: Not Available.") + self.__delete_agg_attrs_from_clone( + self.flow, "Private attributes: Not Available." + ) self.call_checkpoint(self.flow, f) self.__set_attributes_to_clone(self.flow) # Check if all rounds of external loop is executed @@ -342,7 +368,9 @@ def do_task(self, f_name: str) -> Any: # clones are arguments f(*selected_clones) - self.__delete_agg_attrs_from_clone(self.flow, "Private attributes: Not Available.") + self.__delete_agg_attrs_from_clone( + self.flow, "Private attributes: Not Available." 
+ ) # Take the checkpoint of executed step self.call_checkpoint(self.flow, f) self.__set_attributes_to_clone(self.flow) @@ -361,7 +389,9 @@ def do_task(self, f_name: str) -> Any: temp = self.flow.execute_task_args[3:] self.clones_dict, self.instance_snapshot, self.kwargs = temp - self.selected_collaborators = getattr(self.flow, self.kwargs["foreach"]) + self.selected_collaborators = getattr( + self.flow, self.kwargs["foreach"] + ) else: self.kwargs = self.flow.execute_task_args[3] @@ -373,8 +403,13 @@ def do_task(self, f_name: str) -> Any: return f_name if f_name != "end" else None - def send_task_results(self, collab_name: str, round_number: int, next_step: str, - clone_bytes: bytes) -> None: + def send_task_results( + self, + collab_name: str, + round_number: int, + next_step: str, + clone_bytes: bytes, + ) -> None: """ After collaborator execution, collaborator will call this function via gRPc to send next function. @@ -412,11 +447,14 @@ def send_task_results(self, collab_name: str, round_number: int, next_step: str, # Set the event to inform aggregator to resume the flow execution self.collaborator_task_results.set() # Empty tasks_sent_to_collaborators list for next time. - if self.tasks_sent_to_collaborators == len(self.selected_collaborators): + if self.tasks_sent_to_collaborators == len( + self.selected_collaborators + ): self.tasks_sent_to_collaborators = 0 - def valid_collaborator_cn_and_id(self, cert_common_name: str, - collaborator_common_name: str) -> bool: + def valid_collaborator_cn_and_id( + self, cert_common_name: str, collaborator_common_name: str + ) -> bool: """ Determine if the collaborator certificate and ID are valid for this federation. @@ -433,13 +471,17 @@ def valid_collaborator_cn_and_id(self, cert_common_name: str, # FIXME: "" instead of None is just for protobuf compatibility. # Cleaner solution? if self.single_col_cert_common_name == "": - return (cert_common_name == collaborator_common_name - and collaborator_common_name in self.authorized_cols) + return ( + cert_common_name == collaborator_common_name + and collaborator_common_name in self.authorized_cols + ) # otherwise, common_name must be in whitelist and # collaborator_common_name must be in authorized_cols else: - return (cert_common_name == self.single_col_cert_common_name - and collaborator_common_name in self.authorized_cols) + return ( + cert_common_name == self.single_col_cert_common_name + and collaborator_common_name in self.authorized_cols + ) def all_quit_jobs_sent(self) -> bool: """Assert all quit jobs are sent to collaborators.""" diff --git a/openfl/experimental/component/collaborator/__init__.py b/openfl/experimental/component/collaborator/__init__.py index 29ce6da9d3..b9089e0eca 100644 --- a/openfl/experimental/component/collaborator/__init__.py +++ b/openfl/experimental/component/collaborator/__init__.py @@ -1,8 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.experimental.component.collaborator package.""" -from .collaborator import Collaborator - -__all__ = ["Collaborator",] +# FIXME: Too much recursion. 
+from openfl.experimental.component.collaborator.collaborator import Collaborator diff --git a/openfl/experimental/component/collaborator/collaborator.py b/openfl/experimental/component/collaborator/collaborator.py index 65e6210ca4..8dc1b3c8a8 100644 --- a/openfl/experimental/component/collaborator/collaborator.py +++ b/openfl/experimental/component/collaborator/collaborator.py @@ -1,13 +1,10 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Experimental Collaborator module.""" -import time import pickle - -from typing import Any, Callable -from typing import Dict, Tuple +import time from logging import getLogger +from typing import Any, Callable, Dict, Tuple class Collaborator: @@ -28,14 +25,17 @@ class Collaborator: Note: \* - Plan setting. """ - def __init__(self, - collaborator_name: str, - aggregator_uuid: str, - federation_uuid: str, - client: Any, - private_attributes_callable: Any = None, - private_attributes_kwargs: Dict = {}, - **kwargs) -> None: + + def __init__( + self, + collaborator_name: str, + aggregator_uuid: str, + federation_uuid: str, + client: Any, + private_attributes_callable: Any = None, + private_attributes_kwargs: Dict = {}, + **kwargs, + ) -> None: self.name = collaborator_name self.aggregator_uuid = aggregator_uuid @@ -63,9 +63,7 @@ def __initialize_private_attributes(self, kwrags: Dict) -> None: Returns: None """ - self.__private_attrs = self.__private_attrs_callable( - **kwrags - ) + self.__private_attrs = self.__private_attrs_callable(**kwrags) def __set_attributes_to_clone(self, clone: Any) -> None: """ @@ -82,7 +80,9 @@ def __set_attributes_to_clone(self, clone: Any) -> None: for name, attr in self.__private_attrs.items(): setattr(clone, name, attr) - def __delete_agg_attrs_from_clone(self, clone: Any, replace_str: str = None) -> None: + def __delete_agg_attrs_from_clone( + self, clone: Any, replace_str: str = None + ) -> None: """ Remove aggregator private attributes from FLSpec clone before transition from Aggregator step to collaborator steps @@ -99,13 +99,17 @@ def __delete_agg_attrs_from_clone(self, clone: Any, replace_str: str = None) -> if len(self.__private_attrs) > 0: for attr_name in self.__private_attrs: if hasattr(clone, attr_name): - self.__private_attrs.update({attr_name: getattr(clone, attr_name)}) + self.__private_attrs.update( + {attr_name: getattr(clone, attr_name)} + ) if replace_str: setattr(clone, attr_name, replace_str) else: delattr(clone, attr_name) - def call_checkpoint(self, ctx: Any, f: Callable, stream_buffer: Any) -> None: + def call_checkpoint( + self, ctx: Any, f: Callable, stream_buffer: Any + ) -> None: """ Call checkpoint gRPC. @@ -119,7 +123,9 @@ def call_checkpoint(self, ctx: Any, f: Callable, stream_buffer: Any) -> None: """ self.client.call_checkpoint( self.name, - pickle.dumps(ctx), pickle.dumps(f), pickle.dumps(stream_buffer) + pickle.dumps(ctx), + pickle.dumps(f), + pickle.dumps(stream_buffer), ) def run(self) -> None: @@ -157,11 +163,12 @@ def send_task_results(self, next_step: str, clone: Any) -> None: Returns: None """ - self.logger.info(f"Round {self.round_number}," - f" collaborator {self.name} is sending results...") + self.logger.info( + f"Round {self.round_number}," + f" collaborator {self.name} is sending results..." 
+ ) self.client.send_task_results( - self.name, self.round_number, - next_step, pickle.dumps(clone) + self.name, self.round_number, next_step, pickle.dumps(clone) ) def get_tasks(self) -> Tuple: @@ -179,7 +186,9 @@ def get_tasks(self) -> Tuple: """ self.logger.info("Waiting for tasks...") temp = self.client.get_tasks(self.name) - self.round_number, next_step, clone_bytes, sleep_time, time_to_quit = temp + self.round_number, next_step, clone_bytes, sleep_time, time_to_quit = ( + temp + ) return next_step, pickle.loads(clone_bytes), sleep_time, time_to_quit @@ -203,7 +212,9 @@ def do_task(self, f_name: str, ctx: Any) -> Tuple: f = getattr(ctx, f_name) f() # Checkpoint the function - self.__delete_agg_attrs_from_clone(ctx, "Private attributes: Not Available.") + self.__delete_agg_attrs_from_clone( + ctx, "Private attributes: Not Available." + ) self.call_checkpoint(ctx, f, f._stream_buffer) self.__set_attributes_to_clone(ctx) diff --git a/openfl/experimental/federated/__init__.py b/openfl/experimental/federated/__init__.py index 77a79d67f1..fb82b790ea 100644 --- a/openfl/experimental/federated/__init__.py +++ b/openfl/experimental/federated/__init__.py @@ -1,8 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.experimental.federated package.""" -from .plan import Plan # NOQA - -__all__ = ["Plan"] +# FIXME: Recursion! +from openfl.experimental.federated.plan import Plan diff --git a/openfl/experimental/federated/plan/__init__.py b/openfl/experimental/federated/plan/__init__.py index eb1f085d43..9fdecde62c 100644 --- a/openfl/experimental/federated/plan/__init__.py +++ b/openfl/experimental/federated/plan/__init__.py @@ -1,8 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Experimental Plan package.""" -from .plan import Plan - -__all__ = ['Plan',] +# FIXME: Too much recursion in namespace +from openfl.experimental.federated.plan.plan import Plan diff --git a/openfl/experimental/federated/plan/plan.py b/openfl/experimental/federated/plan/plan.py index 4fcda43703..4b378ef02e 100644 --- a/openfl/experimental/federated/plan/plan.py +++ b/openfl/experimental/federated/plan/plan.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Plan module.""" import inspect from hashlib import sha384 @@ -9,13 +8,13 @@ from os.path import splitext from pathlib import Path -from yaml import dump -from yaml import safe_load -from yaml import SafeDumper +from yaml import SafeDumper, dump, safe_load from openfl.experimental.interface.cli.cli_helper import WORKSPACE -from openfl.experimental.transport import AggregatorGRPCClient -from openfl.experimental.transport import AggregatorGRPCServer +from openfl.experimental.transport import ( + AggregatorGRPCClient, + AggregatorGRPCServer, +) from openfl.utilities.utils import getfqdn_env SETTINGS = "settings" @@ -43,6 +42,7 @@ def dump(yaml_path, config, freeze=False): """Dump the plan config to YAML file.""" class NoAliasDumper(SafeDumper): + def ignore_aliases(self, data): return True @@ -113,14 +113,18 @@ def parse( if SETTINGS in defaults: # override defaults with section settings - defaults[SETTINGS].update(plan.config[section][SETTINGS]) + defaults[SETTINGS].update( + plan.config[section][SETTINGS] + ) plan.config[section][SETTINGS] = defaults[SETTINGS] defaults.update(plan.config[section]) plan.config[section] = defaults - plan.authorized_cols = Plan.load(cols_config_path).get("collaborators", []) + plan.authorized_cols 
= Plan.load(cols_config_path).get( + "collaborators", [] + ) if resolve: plan.resolve() @@ -177,8 +181,12 @@ def build(template, settings, **override): f"from [red]{module_path}[/] Module.", extra={"markup": True}, ) - Plan.logger.debug(f"Settings [red]👆[/] {settings}", extra={"markup": True}) - Plan.logger.debug(f"Override [red]👆[/] {override}", extra={"markup": True}) + Plan.logger.debug( + f"Settings [red]👆[/] {settings}", extra={"markup": True} + ) + Plan.logger.debug( + f"Override [red]👆[/] {override}", extra={"markup": True} + ) settings.update(**override) module = import_module(module_path) @@ -233,7 +241,8 @@ def hash(self): # NOQA """Generate hash for this instance.""" self.hash_ = sha384(dump(self.config).encode("utf-8")) Plan.logger.info( - f"FL-Plan hash is [blue]{self.hash_.hexdigest()}[/]", extra={"markup": True} + f"FL-Plan hash is [blue]{self.hash_.hexdigest()}[/]", + extra={"markup": True}, ) return self.hash_.hexdigest() @@ -243,7 +252,9 @@ def resolve(self): self.federation_uuid = f"{self.name}_{self.hash[:8]}" self.aggregator_uuid = f"aggregator_{self.federation_uuid}" - self.rounds_to_train = self.config["aggregator"][SETTINGS]["rounds_to_train"] + self.rounds_to_train = self.config["aggregator"][SETTINGS][ + "rounds_to_train" + ] if self.config["network"][SETTINGS]["agg_addr"] == AUTO: self.config["network"][SETTINGS]["agg_addr"] = getfqdn_env() @@ -256,18 +267,20 @@ def resolve(self): def get_aggregator(self): """Get federation aggregator.""" defaults = self.config.get( - "aggregator", { - TEMPLATE: "openfl.experimental.Aggregator", - SETTINGS: {} - } + "aggregator", + {TEMPLATE: "openfl.experimental.Aggregator", SETTINGS: {}}, ) defaults[SETTINGS]["aggregator_uuid"] = self.aggregator_uuid defaults[SETTINGS]["federation_uuid"] = self.federation_uuid defaults[SETTINGS]["authorized_cols"] = self.authorized_cols - private_attrs_callable, private_attrs_kwargs = self.get_private_attr("aggregator") - defaults[SETTINGS]["private_attributes_callable"] = private_attrs_callable + private_attrs_callable, private_attrs_kwargs = self.get_private_attr( + "aggregator" + ) + defaults[SETTINGS][ + "private_attributes_callable" + ] = private_attrs_callable defaults[SETTINGS]["private_attributes_kwargs"] = private_attrs_kwargs defaults[SETTINGS]["flow"] = self.get_flow() @@ -302,19 +315,21 @@ def get_collaborator( ): """Get collaborator.""" defaults = self.config.get( - "collaborator", { - TEMPLATE: "openfl.experimental.Collaborator", - SETTINGS: {} - } + "collaborator", + {TEMPLATE: "openfl.experimental.Collaborator", SETTINGS: {}}, ) defaults[SETTINGS]["collaborator_name"] = collaborator_name defaults[SETTINGS]["aggregator_uuid"] = self.aggregator_uuid defaults[SETTINGS]["federation_uuid"] = self.federation_uuid - private_attrs_callable, private_attrs_kwargs = self.get_private_attr(collaborator_name) + private_attrs_callable, private_attrs_kwargs = self.get_private_attr( + collaborator_name + ) - defaults[SETTINGS]["private_attributes_callable"] = private_attrs_callable + defaults[SETTINGS][ + "private_attributes_callable" + ] = private_attrs_callable defaults[SETTINGS]["private_attributes_kwargs"] = private_attrs_kwargs if client is not None: @@ -367,7 +382,11 @@ def get_client( return self.client_ def get_server( - self, root_certificate=None, private_key=None, certificate=None, **kwargs + self, + root_certificate=None, + private_key=None, + certificate=None, + **kwargs, ): """Get gRPC server of the aggregator instance.""" common_name = 
self.config["network"][SETTINGS]["agg_addr"].lower() @@ -396,10 +415,8 @@ def get_server( def get_flow(self): """instantiates federated flow object""" defaults = self.config.get( - "federated_flow", { - TEMPLATE: self.config["federated_flow"]["template"], - SETTINGS: {} - }, + "federated_flow", + {TEMPLATE: self.config["federated_flow"]["template"], SETTINGS: {}}, ) defaults = self.import_kwargs_modules(defaults) @@ -407,6 +424,7 @@ def get_flow(self): return self.flow_ def import_kwargs_modules(self, defaults): + def import_nested_settings(settings): for key, value in settings.items(): if isinstance(value, dict): @@ -419,8 +437,8 @@ def import_nested_settings(settings): if import_module(module_path): module = import_module(module_path) value_defaults_data = { - 'template': value, - 'settings': settings.get('settings', {}), + "template": value, + "settings": settings.get("settings", {}), } attr = getattr(module, class_name) @@ -440,9 +458,10 @@ def get_private_attr(self, private_attr_name=None): private_attrs_kwargs = {} import os - from openfl.experimental.federated.plan import Plan from pathlib import Path + from openfl.experimental.federated.plan import Plan + data_yaml = "plan/data.yaml" if os.path.exists(data_yaml) and os.path.isfile(data_yaml): @@ -450,7 +469,9 @@ def get_private_attr(self, private_attr_name=None): if d.get(private_attr_name, None): private_attrs_callable = { - "template": d.get(private_attr_name)["callable_func"]["template"] + "template": d.get(private_attr_name)["callable_func"][ + "template" + ] } private_attrs_kwargs = self.import_kwargs_modules( @@ -458,7 +479,9 @@ def get_private_attr(self, private_attr_name=None): )["settings"] if isinstance(private_attrs_callable, dict): - private_attrs_callable = Plan.import_(**private_attrs_callable) + private_attrs_callable = Plan.import_( + **private_attrs_callable + ) elif not callable(private_attrs_callable): raise TypeError( f"private_attrs_callable should be callable object " diff --git a/openfl/experimental/interface/__init__.py b/openfl/experimental/interface/__init__.py index fc03bd8459..14d076f473 100644 --- a/openfl/experimental/interface/__init__.py +++ b/openfl/experimental/interface/__init__.py @@ -1,9 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.experimental.interface package.""" -from .fl_spec import FLSpec -from .participants import Aggregator, Collaborator - -__all__ = ["FLSpec", "Aggregator", "Collaborator"] +from openfl.experimental.interface.fl_spec import FLSpec +from openfl.experimental.interface.participants import Aggregator, Collaborator diff --git a/openfl/experimental/interface/cli/aggregator.py b/openfl/experimental/interface/cli/aggregator.py index c189cda215..ec307e361a 100644 --- a/openfl/experimental/interface/cli/aggregator.py +++ b/openfl/experimental/interface/cli/aggregator.py @@ -6,12 +6,8 @@ import threading from logging import getLogger -from click import echo -from click import group -from click import option -from click import pass_context from click import Path as ClickPath -from click import style +from click import echo, group, option, pass_context, style from openfl.utilities import click_types from openfl.utilities.path_check import is_directory_traversal @@ -24,19 +20,34 @@ @pass_context def aggregator(context): """Manage Federated Learning Aggregator.""" - context.obj['group'] = 'aggregator' - - -@aggregator.command(name='start') -@option('-p', '--plan', required=False, - help='Federated learning plan [plan/plan.yaml]', 
- default='plan/plan.yaml', - type=ClickPath(exists=True)) -@option('-c', '--authorized_cols', required=False, - help='Authorized collaborator list [plan/cols.yaml]', - default='plan/cols.yaml', type=ClickPath(exists=True)) -@option('-s', '--secure', required=False, - help='Enable Intel SGX Enclave', is_flag=True, default=False) + context.obj["group"] = "aggregator" + + +@aggregator.command(name="start") +@option( + "-p", + "--plan", + required=False, + help="Federated learning plan [plan/plan.yaml]", + default="plan/plan.yaml", + type=ClickPath(exists=True), +) +@option( + "-c", + "--authorized_cols", + required=False, + help="Authorized collaborator list [plan/cols.yaml]", + default="plan/cols.yaml", + type=ClickPath(exists=True), +) +@option( + "-s", + "--secure", + required=False, + help="Enable Intel SGX Enclave", + is_flag=True, + default=False, +) def start_(plan, authorized_cols, secure): """Start the aggregator service.""" import os @@ -45,30 +56,39 @@ def start_(plan, authorized_cols, secure): from openfl.experimental.federated.plan import Plan if is_directory_traversal(plan): - echo('Federated learning plan path is out of the openfl workspace scope.') + echo( + "Federated learning plan path is out of the openfl workspace scope." + ) sys.exit(1) if is_directory_traversal(authorized_cols): - echo('Authorized collaborator list file path is out of the openfl workspace scope.') + echo( + "Authorized collaborator list file path is out of the openfl workspace scope." + ) sys.exit(1) - plan = Plan.parse(plan_config_path=Path(plan).absolute(), - cols_config_path=Path(authorized_cols).absolute()) + plan = Plan.parse( + plan_config_path=Path(plan).absolute(), + cols_config_path=Path(authorized_cols).absolute(), + ) - if not os.path.exists('plan/data.yaml'): + if not os.path.exists("plan/data.yaml"): logger.warning( - 'Aggregator private attributes are set to None as plan/data.yaml not found' - + ' in workspace.') + "Aggregator private attributes are set to None as plan/data.yaml not found" + + " in workspace." + ) else: import yaml from yaml.loader import SafeLoader - with open('plan/data.yaml', 'r') as f: + + with open("plan/data.yaml", "r") as f: data = yaml.load(f, Loader=SafeLoader) if data.get("aggregator", None) is None: logger.warning( - 'Aggregator private attributes are set to None as no aggregator' - + ' attributes found in plan/data.yaml.') + "Aggregator private attributes are set to None as no aggregator" + + " attributes found in plan/data.yaml." 
+ ) - logger.info('🧿 Starting the Aggregator Service.') + logger.info("🧿 Starting the Aggregator Service.") agg_server = plan.get_server() agg_server.is_server_started = False @@ -81,55 +101,65 @@ def start_(plan, authorized_cols, secure): break -@aggregator.command(name='generate-cert-request') -@option('--fqdn', required=False, type=click_types.FQDN, - help=f'The fully qualified domain name of' - f' aggregator node [{getfqdn_env()}]', - default=getfqdn_env()) +@aggregator.command(name="generate-cert-request") +@option( + "--fqdn", + required=False, + type=click_types.FQDN, + help=f"The fully qualified domain name of" + f" aggregator node [{getfqdn_env()}]", + default=getfqdn_env(), +) def _generate_cert_request(fqdn): generate_cert_request(fqdn) def generate_cert_request(fqdn): """Create aggregator certificate key pair.""" + from openfl.cryptography.io import get_csr_hash, write_crt, write_key from openfl.cryptography.participant import generate_csr - from openfl.cryptography.io import write_crt - from openfl.cryptography.io import write_key - from openfl.cryptography.io import get_csr_hash from openfl.experimental.interface.cli.cli_helper import CERT_DIR if fqdn is None: fqdn = getfqdn_env() - common_name = f'{fqdn}'.lower() - subject_alternative_name = f'DNS:{common_name}' - file_name = f'agg_{common_name}' + common_name = f"{fqdn}".lower() + subject_alternative_name = f"DNS:{common_name}" + file_name = f"agg_{common_name}" - echo(f'Creating AGGREGATOR certificate key pair with following settings: ' - f'CN={style(common_name, fg="red")},' - f' SAN={style(subject_alternative_name, fg="red")}') + echo( + f"Creating AGGREGATOR certificate key pair with following settings: " + f'CN={style(common_name, fg="red")},' + f' SAN={style(subject_alternative_name, fg="red")}' + ) server_private_key, server_csr = generate_csr(common_name, server=True) - (CERT_DIR / 'server').mkdir(parents=True, exist_ok=True) + (CERT_DIR / "server").mkdir(parents=True, exist_ok=True) - echo(' Writing AGGREGATOR certificate key pair to: ' + style( - f'{CERT_DIR}/server', fg='green')) + echo( + " Writing AGGREGATOR certificate key pair to: " + + style(f"{CERT_DIR}/server", fg="green") + ) # Print csr hash before writing csr to disk csr_hash = get_csr_hash(server_csr) - echo('The CSR Hash ' + style(f'{csr_hash}', fg='red')) + echo("The CSR Hash " + style(f"{csr_hash}", fg="red")) # Write aggregator csr and key to disk - write_crt(server_csr, CERT_DIR / 'server' / f'{file_name}.csr') - write_key(server_private_key, CERT_DIR / 'server' / f'{file_name}.key') - - -@aggregator.command(name='certify') -@option('-n', '--fqdn', type=click_types.FQDN, - help=f'The fully qualified domain name of aggregator node [{getfqdn_env()}]', - default=getfqdn_env()) -@option('-s', '--silent', help='Do not prompt', is_flag=True) + write_crt(server_csr, CERT_DIR / "server" / f"{file_name}.csr") + write_key(server_private_key, CERT_DIR / "server" / f"{file_name}.key") + + +@aggregator.command(name="certify") +@option( + "-n", + "--fqdn", + type=click_types.FQDN, + help=f"The fully qualified domain name of aggregator node [{getfqdn_env()}]", + default=getfqdn_env(), +) +@option("-s", "--silent", help="Do not prompt", is_flag=True) def _certify(fqdn, silent): certify(fqdn, silent) @@ -141,70 +171,85 @@ def certify(fqdn, silent): from click import confirm from openfl.cryptography.ca import sign_certificate - from openfl.cryptography.io import read_crt - from openfl.cryptography.io import read_csr - from openfl.cryptography.io import read_key - 
from openfl.cryptography.io import write_crt + from openfl.cryptography.io import read_crt, read_csr, read_key, write_crt from openfl.experimental.interface.cli.cli_helper import CERT_DIR if fqdn is None: fqdn = getfqdn_env() - common_name = f'{fqdn}'.lower() - file_name = f'agg_{common_name}' - cert_name = f'server/{file_name}' - signing_key_path = 'ca/signing-ca/private/signing-ca.key' - signing_crt_path = 'ca/signing-ca.crt' + common_name = f"{fqdn}".lower() + file_name = f"agg_{common_name}" + cert_name = f"server/{file_name}" + signing_key_path = "ca/signing-ca/private/signing-ca.key" + signing_crt_path = "ca/signing-ca.crt" # Load CSR - csr_path_absolute_path = Path(CERT_DIR / f'{cert_name}.csr').absolute() + csr_path_absolute_path = Path(CERT_DIR / f"{cert_name}.csr").absolute() if not csr_path_absolute_path.exists(): - echo(style('Aggregator certificate signing request not found.', fg='red') - + ' Please run `fx aggregator generate-cert-request`' - ' to generate the certificate request.') + echo( + style("Aggregator certificate signing request not found.", fg="red") + + " Please run `fx aggregator generate-cert-request`" + " to generate the certificate request." + ) csr, csr_hash = read_csr(csr_path_absolute_path) # Load private signing key - private_sign_key_absolute_path = Path(CERT_DIR / signing_key_path).absolute() + private_sign_key_absolute_path = Path( + CERT_DIR / signing_key_path + ).absolute() if not private_sign_key_absolute_path.exists(): - echo(style('Signing key not found.', fg='red') - + ' Please run `fx workspace certify`' - ' to initialize the local certificate authority.') + echo( + style("Signing key not found.", fg="red") + + " Please run `fx workspace certify`" + " to initialize the local certificate authority." + ) signing_key = read_key(private_sign_key_absolute_path) # Load signing cert signing_crt_absolute_path = Path(CERT_DIR / signing_crt_path).absolute() if not signing_crt_absolute_path.exists(): - echo(style('Signing certificate not found.', fg='red') - + ' Please run `fx workspace certify`' - ' to initialize the local certificate authority.') + echo( + style("Signing certificate not found.", fg="red") + + " Please run `fx workspace certify`" + " to initialize the local certificate authority." + ) signing_crt = read_crt(signing_crt_absolute_path) - echo('The CSR Hash for file ' - + style(f'{cert_name}.csr', fg='green') - + ' = ' - + style(f'{csr_hash}', fg='red')) + echo( + "The CSR Hash for file " + + style(f"{cert_name}.csr", fg="green") + + " = " + + style(f"{csr_hash}", fg="red") + ) - crt_path_absolute_path = Path(CERT_DIR / f'{cert_name}.crt').absolute() + crt_path_absolute_path = Path(CERT_DIR / f"{cert_name}.crt").absolute() if silent: - echo(' Warning: manual check of certificate hashes is bypassed in silent mode.') - echo(' Signing AGGREGATOR certificate') - signed_agg_cert = sign_certificate(csr, signing_key, signing_crt.subject) + echo( + " Warning: manual check of certificate hashes is bypassed in silent mode." 
+ ) + echo(" Signing AGGREGATOR certificate") + signed_agg_cert = sign_certificate( + csr, signing_key, signing_crt.subject + ) write_crt(signed_agg_cert, crt_path_absolute_path) else: - echo('Make sure the two hashes above are the same.') - if confirm('Do you want to sign this certificate?'): + echo("Make sure the two hashes above are the same.") + if confirm("Do you want to sign this certificate?"): - echo(' Signing AGGREGATOR certificate') - signed_agg_cert = sign_certificate(csr, signing_key, signing_crt.subject) + echo(" Signing AGGREGATOR certificate") + signed_agg_cert = sign_certificate( + csr, signing_key, signing_crt.subject + ) write_crt(signed_agg_cert, crt_path_absolute_path) else: - echo(style('Not signing certificate.', fg='red') - + ' Please check with this AGGREGATOR to get the correct' - ' certificate for this federation.') + echo( + style("Not signing certificate.", fg="red") + + " Please check with this AGGREGATOR to get the correct" + " certificate for this federation." + ) diff --git a/openfl/experimental/interface/cli/cli_helper.py b/openfl/experimental/interface/cli/cli_helper.py index e552b17209..d8ddb2bd48 100644 --- a/openfl/experimental/interface/cli/cli_helper.py +++ b/openfl/experimental/interface/cli/cli_helper.py @@ -3,23 +3,20 @@ """Module with auxiliary CLI helper functions.""" from itertools import islice -from os import environ -from os import stat +from os import environ, stat from pathlib import Path from sys import argv -from click import echo -from click import style -from yaml import FullLoader -from yaml import load +from click import echo, style +from yaml import FullLoader, load FX = argv[0] SITEPACKS = Path(__file__).parent.parent.parent.parent.parent -WORKSPACE = SITEPACKS / 'openfl-workspace' / 'experimental' -TUTORIALS = SITEPACKS / 'openfl-tutorials' -OPENFL_USERDIR = Path.home() / '.openfl' -CERT_DIR = Path('cert').absolute() +WORKSPACE = SITEPACKS / "openfl-workspace" / "experimental" +TUTORIALS = SITEPACKS / "openfl-tutorials" +OPENFL_USERDIR = Path.home() / ".openfl" +CERT_DIR = Path("cert").absolute() def pretty(o): @@ -27,40 +24,43 @@ def pretty(o): m = max(map(len, o.keys())) for k, v in o.items(): - echo(style(f'{k:<{m}} : ', fg='blue') + style(f'{v}', fg='cyan')) + echo(style(f"{k:<{m}} : ", fg="blue") + style(f"{v}", fg="cyan")) def tree(path): """Print current directory file tree.""" - echo(f'+ {path}') + echo(f"+ {path}") - for path in sorted(path.rglob('*')): + for path in sorted(path.rglob("*")): depth = len(path.relative_to(path).parts) - space = ' ' * depth + space = " " * depth if path.is_file(): - echo(f'{space}f {path.name}') + echo(f"{space}f {path.name}") else: - echo(f'{space}d {path.name}') + echo(f"{space}d {path.name}") -def print_tree(dir_path: Path, level: int = -1, - limit_to_directories: bool = False, - length_limit: int = 1000): +def print_tree( + dir_path: Path, + level: int = -1, + limit_to_directories: bool = False, + length_limit: int = 1000, +): """Given a directory Path object print a visual tree structure.""" - space = ' ' - branch = '│ ' - tee = '├── ' - last = '└── ' + space = " " + branch = "│ " + tee = "├── " + last = "└── " - echo('\nNew experimental workspace directory structure:') + echo("\nNew experimental workspace directory structure:") dir_path = Path(dir_path) # accept string coerceable to Path files = 0 directories = 0 - def inner(dir_path: Path, prefix: str = '', level=-1): + def inner(dir_path: Path, prefix: str = "", level=-1): nonlocal files, directories if not level: 
return # 0, stop iterating @@ -74,8 +74,9 @@ def inner(dir_path: Path, prefix: str = '', level=-1): yield prefix + pointer + path.name directories += 1 extension = branch if pointer == tee else space - yield from inner(path, prefix=prefix + extension, - level=level - 1) + yield from inner( + path, prefix=prefix + extension, level=level - 1 + ) elif not limit_to_directories: yield prefix + pointer + path.name files += 1 @@ -85,12 +86,18 @@ def inner(dir_path: Path, prefix: str = '', level=-1): for line in islice(iterator, length_limit): echo(line) if next(iterator, None): - echo(f'... length_limit, {length_limit}, reached, counted:') - echo(f'\n{directories} directories' + (f', {files} files' if files else '')) - - -def copytree(src, dst, symlinks=False, ignore=None, - ignore_dangling_symlinks=False, dirs_exist_ok=False): + echo(f"... length_limit, {length_limit}, reached, counted:") + echo(f"\n{directories} directories" + (f", {files} files" if files else "")) + + +def copytree( + src, + dst, + symlinks=False, + ignore=None, + ignore_dangling_symlinks=False, + dirs_exist_ok=False, +): """From Python 3.8 'shutil' which include 'dirs_exist_ok' option.""" import os import shutil @@ -109,7 +116,9 @@ def _copytree(): os.makedirs(dst, exist_ok=dirs_exist_ok) errors = [] - use_srcentry = copy_function is shutil.copy2 or copy_function is shutil.copy + use_srcentry = ( + copy_function is shutil.copy2 or copy_function is shutil.copy + ) for srcentry in entries: if srcentry.name in ignored_names: @@ -119,7 +128,7 @@ def _copytree(): srcobj = srcentry if use_srcentry else srcname try: is_symlink = srcentry.is_symlink() - if is_symlink and os.name == 'nt': + if is_symlink and os.name == "nt": lstat = srcentry.stat(follow_symlinks=False) if lstat.st_reparse_tag == stat.IO_REPARSE_TAG_MOUNT_POINT: is_symlink = False @@ -127,20 +136,33 @@ def _copytree(): linkto = os.readlink(srcname) if symlinks: os.symlink(linkto, dstname) - shutil.copystat(srcobj, dstname, - follow_symlinks=not symlinks) + shutil.copystat( + srcobj, dstname, follow_symlinks=not symlinks + ) else: - if (not os.path.exists(linkto) - and ignore_dangling_symlinks): + if ( + not os.path.exists(linkto) + and ignore_dangling_symlinks + ): continue if srcentry.is_dir(): - copytree(srcobj, dstname, symlinks, ignore, - dirs_exist_ok=dirs_exist_ok) + copytree( + srcobj, + dstname, + symlinks, + ignore, + dirs_exist_ok=dirs_exist_ok, + ) else: copy_function(srcobj, dstname) elif srcentry.is_dir(): - copytree(srcobj, dstname, symlinks, ignore, - dirs_exist_ok=dirs_exist_ok) + copytree( + srcobj, + dstname, + symlinks, + ignore, + dirs_exist_ok=dirs_exist_ok, + ) else: copy_function(srcobj, dstname) except OSError as why: @@ -150,7 +172,7 @@ def _copytree(): try: shutil.copystat(src, dst) except OSError as why: - if getattr(why, 'winerror', None) is None: + if getattr(why, "winerror", None) is None: errors.append((src, dst, str(why))) if errors: raise Exception(errors) @@ -162,21 +184,21 @@ def _copytree(): def get_workspace_parameter(name): """Get a parameter from the workspace config file (.workspace).""" # Update the .workspace file to show the current workspace plan - workspace_file = '.workspace' + workspace_file = ".workspace" - with open(workspace_file, 'r', encoding='utf-8') as f: + with open(workspace_file, "r", encoding="utf-8") as f: doc = load(f, Loader=FullLoader) if not doc: # YAML is not correctly formatted doc = {} # Create empty dictionary if name not in doc.keys() or not doc[name]: # List doesn't exist - return '' + return "" else: 
return doc[name] -def check_varenv(env: str = '', args: dict = None): +def check_varenv(env: str = "", args: dict = None): """Update "args" (dictionary) with if env has a defined value in the host.""" if args is None: args = {} @@ -187,23 +209,23 @@ def check_varenv(env: str = '', args: dict = None): return args -def get_fx_path(curr_path=''): +def get_fx_path(curr_path=""): """Return the absolute path to fx binary.""" - import re import os + import re - match = re.search('lib', curr_path) + match = re.search("lib", curr_path) idx = match.end() path_prefix = curr_path[0:idx] - bin_path = re.sub('lib', 'bin', path_prefix) - fx_path = os.path.join(bin_path, 'fx') + bin_path = re.sub("lib", "bin", path_prefix) + fx_path = os.path.join(bin_path, "fx") return fx_path def remove_line_from_file(pkg, filename): """Remove line that contains `pkg` from the `filename` file.""" - with open(filename, 'r+', encoding='utf-8') as f: + with open(filename, "r+", encoding="utf-8") as f: d = f.readlines() f.seek(0) for i in d: @@ -214,7 +236,7 @@ def remove_line_from_file(pkg, filename): def replace_line_in_file(line, line_num_to_replace, filename): """Replace line at `line_num_to_replace` with `line`.""" - with open(filename, 'r+', encoding='utf-8') as f: + with open(filename, "r+", encoding="utf-8") as f: d = f.readlines() f.seek(0) for idx, i in enumerate(d): diff --git a/openfl/experimental/interface/cli/collaborator.py b/openfl/experimental/interface/cli/collaborator.py index e31de19a88..c5c6d06ac3 100644 --- a/openfl/experimental/interface/cli/collaborator.py +++ b/openfl/experimental/interface/cli/collaborator.py @@ -2,20 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 """Collaborator module.""" -import sys import os +import sys from logging import getLogger -from click import echo -from click import group -from click import option -from click import pass_context from click import Path as ClickPath -from click import style +from click import echo, group, option, pass_context, style from openfl.utilities.path_check import is_directory_traversal - logger = getLogger(__name__) @@ -42,12 +37,12 @@ def collaborator(context): help="The certified common name of the collaborator", ) @option( - '-s', - '--secure', + "-s", + "--secure", required=False, - help='Enable Intel SGX Enclave', + help="Enable Intel SGX Enclave", is_flag=True, - default=False + default=False, ) def start_(plan, collaborator_name, secure, data_config="plan/data.yaml"): """Start a collaborator service.""" @@ -56,7 +51,9 @@ def start_(plan, collaborator_name, secure, data_config="plan/data.yaml"): from openfl.experimental.federated import Plan if plan and is_directory_traversal(plan): - echo("Federated learning plan path is out of the openfl workspace scope.") + echo( + "Federated learning plan path is out of the openfl workspace scope." + ) sys.exit(1) if data_config and is_directory_traversal(data_config): echo( @@ -70,20 +67,24 @@ def start_(plan, collaborator_name, secure, data_config="plan/data.yaml"): ) if not os.path.exists(data_config): - logger.warning('Collaborator private attributes are set to None as' - f' {data_config} not found in workspace.') + logger.warning( + "Collaborator private attributes are set to None as" + f" {data_config} not found in workspace." 
+ ) else: import yaml from yaml.loader import SafeLoader + collaborator_name = collaborator_name.lower() - with open(data_config, 'r') as f: + with open(data_config, "r") as f: data = yaml.load(f, Loader=SafeLoader) if data.get(collaborator_name, None) is None: logger.warning( - f'Collaborator private attributes are set to None as no attributes' - f' for {collaborator_name} found in {data_config}.') + f"Collaborator private attributes are set to None as no attributes" + f" for {collaborator_name} found in {data_config}." + ) - logger.info('🧿 Starting the Collaborator Service.') + logger.info("🧿 Starting the Collaborator Service.") plan.get_collaborator(collaborator_name).run() @@ -113,10 +114,8 @@ def generate_cert_request(collaborator_name, silent, skip_package): Then create a package with the CSR to send for signing. """ + from openfl.cryptography.io import get_csr_hash, write_crt, write_key from openfl.cryptography.participant import generate_csr - from openfl.cryptography.io import write_crt - from openfl.cryptography.io import write_key - from openfl.cryptography.io import get_csr_hash from openfl.experimental.interface.cli.cli_helper import CERT_DIR common_name = f"{collaborator_name}".lower() @@ -147,14 +146,11 @@ def generate_cert_request(collaborator_name, silent, skip_package): write_key(client_private_key, CERT_DIR / "client" / f"{file_name}.key") if not skip_package: - from shutil import copytree - from shutil import ignore_patterns - from shutil import make_archive - from tempfile import mkdtemp - from os.path import basename - from os.path import join - from os import remove from glob import glob + from os import remove + from os.path import basename, join + from shutil import copytree, ignore_patterns, make_archive + from tempfile import mkdtemp from openfl.utilities.utils import rmtree @@ -178,7 +174,8 @@ def generate_cert_request(collaborator_name, silent, skip_package): rmtree(tmp_dir) echo( - f"Archive {archive_file_name} with certificate signing" f" request created" + f"Archive {archive_file_name} with certificate signing" + f" request created" ) echo( "This file should be sent to the certificate authority" @@ -200,11 +197,10 @@ def register_collaborator(file_name): """ from os.path import isfile - from yaml import dump - from yaml import FullLoader - from yaml import load from pathlib import Path + from yaml import FullLoader, dump, load + col_name = find_certificate_name(file_name) cols_file = Path("plan/cols.yaml").absolute() @@ -272,22 +268,17 @@ def certify_(collaborator_name, silent, request_pkg, import_): def certify(collaborator_name, silent, request_pkg=None, import_=False): """Sign/certify collaborator certificate key pair.""" - from click import confirm - from pathlib import Path - from shutil import copy - from shutil import make_archive - from shutil import unpack_archive from glob import glob - from os.path import basename - from os.path import join - from os.path import splitext from os import remove + from os.path import basename, join, splitext + from pathlib import Path + from shutil import copy, make_archive, unpack_archive + from tempfile import mkdtemp + + from click import confirm + from openfl.cryptography.ca import sign_certificate - from openfl.cryptography.io import read_crt - from openfl.cryptography.io import read_csr - from openfl.cryptography.io import read_key - from openfl.cryptography.io import write_crt + from openfl.cryptography.io import read_crt, read_csr, read_key, write_crt from openfl.experimental.interface.cli.cli_helper import 
CERT_DIR from openfl.utilities.utils import rmtree @@ -318,7 +309,10 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False): # Load CSR if not Path(f"{cert_name}.csr").exists(): echo( - style("Collaborator certificate signing request not found.", fg="red") + style( + "Collaborator certificate signing request not found.", + fg="red", + ) + " Please run `fx collaborator generate-cert-request`" " to generate the certificate request." ) @@ -357,7 +351,9 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False): echo( " Warning: manual check of certificate hashes is bypassed in silent mode." ) - signed_col_cert = sign_certificate(csr, signing_key, signing_crt.subject) + signed_col_cert = sign_certificate( + csr, signing_key, signing_crt.subject + ) write_crt(signed_col_cert, f"{cert_name}.crt") register_collaborator(CERT_DIR / "client" / f"{file_name}.crt") diff --git a/openfl/experimental/interface/cli/experimental.py b/openfl/experimental/interface/cli/experimental.py index e97dbabeb9..f6ed41e4d3 100644 --- a/openfl/experimental/interface/cli/experimental.py +++ b/openfl/experimental/interface/cli/experimental.py @@ -5,8 +5,7 @@ import os from pathlib import Path -from click import group -from click import pass_context +from click import group, pass_context @group() @@ -19,7 +18,8 @@ def experimental(context): @experimental.command(name="deactivate") def deactivate(): """Deactivate experimental environment.""" - settings = Path("~").expanduser().joinpath( - ".openfl", "experimental").resolve() + settings = ( + Path("~").expanduser().joinpath(".openfl", "experimental").resolve() + ) os.remove(settings) diff --git a/openfl/experimental/interface/cli/plan.py b/openfl/experimental/interface/cli/plan.py index 3026b182c9..f2ae1ede2c 100644 --- a/openfl/experimental/interface/cli/plan.py +++ b/openfl/experimental/interface/cli/plan.py @@ -5,11 +5,8 @@ import sys from logging import getLogger -from click import echo -from click import group -from click import option -from click import pass_context from click import Path as ClickPath +from click import echo, group, option, pass_context from openfl.utilities.path_check import is_directory_traversal @@ -20,24 +17,43 @@ @pass_context def plan(context): """Manage Federated Learning Plans.""" - context.obj['group'] = 'plan' + context.obj["group"] = "plan" @plan.command() @pass_context -@option('-p', '--plan_config', required=False, - help='Federated learning plan [plan/plan.yaml]', - default='plan/plan.yaml', type=ClickPath(exists=True)) -@option('-c', '--cols_config', required=False, - help='Authorized collaborator list [plan/cols.yaml]', - default='plan/cols.yaml', type=ClickPath(exists=True)) -@option('-d', '--data_config', required=False, - help='The data set/shard configuration file [plan/data.yaml]', - default='plan/data.yaml') -@option('-a', '--aggregator_address', required=False, - help='The FQDN of the federation agregator') -def initialize(context, plan_config, cols_config, data_config, - aggregator_address): +@option( + "-p", + "--plan_config", + required=False, + help="Federated learning plan [plan/plan.yaml]", + default="plan/plan.yaml", + type=ClickPath(exists=True), +) +@option( + "-c", + "--cols_config", + required=False, + help="Authorized collaborator list [plan/cols.yaml]", + default="plan/cols.yaml", + type=ClickPath(exists=True), +) +@option( + "-d", + "--data_config", + required=False, + help="The data set/shard configuration file [plan/data.yaml]", + default="plan/data.yaml", +) +@option( + "-a", 
+ "--aggregator_address", + required=False, + help="The FQDN of the federation agregator", +) +def initialize( + context, plan_config, cols_config, data_config, aggregator_address +): """ Initialize Data Science plan. @@ -51,34 +67,42 @@ def initialize(context, plan_config, cols_config, data_config, for p in [plan_config, cols_config, data_config]: if is_directory_traversal(p): - echo(f'{p} is out of the openfl workspace scope.') + echo(f"{p} is out of the openfl workspace scope.") sys.exit(1) plan_config = Path(plan_config).absolute() cols_config = Path(cols_config).absolute() data_config = Path(data_config).absolute() - plan = Plan.parse(plan_config_path=plan_config, - cols_config_path=cols_config, - data_config_path=data_config) + plan = Plan.parse( + plan_config_path=plan_config, + cols_config_path=cols_config, + data_config_path=data_config, + ) plan_origin = Plan.parse(plan_config, resolve=False).config - if (plan_origin['network']['settings']['agg_addr'] == 'auto' - or aggregator_address): - plan_origin['network']['settings']['agg_addr'] = aggregator_address or getfqdn_env() + if ( + plan_origin["network"]["settings"]["agg_addr"] == "auto" + or aggregator_address + ): + plan_origin["network"]["settings"]["agg_addr"] = ( + aggregator_address or getfqdn_env() + ) - logger.warn(f'Patching Aggregator Addr in Plan' - f" πŸ † {plan_origin['network']['settings']['agg_addr']}") + logger.warn( + f"Patching Aggregator Addr in Plan" + f" πŸ † {plan_origin['network']['settings']['agg_addr']}" + ) Plan.dump(plan_config, plan_origin) plan.config = plan_origin # Record that plan with this hash has been initialized - if 'plans' not in context.obj: - context.obj['plans'] = [] - context.obj['plans'].append(f'{plan_config.stem}_{plan.hash[:8]}') + if "plans" not in context.obj: + context.obj["plans"] = [] + context.obj["plans"].append(f"{plan_config.stem}_{plan.hash[:8]}") logger.info(f"{context.obj['plans']}") diff --git a/openfl/experimental/interface/cli/workspace.py b/openfl/experimental/interface/cli/workspace.py index f76391e0f8..7d21f2566b 100644 --- a/openfl/experimental/interface/cli/workspace.py +++ b/openfl/experimental/interface/cli/workspace.py @@ -2,20 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 """Workspace module.""" -import sys import os +import sys +from logging import getLogger from pathlib import Path from typing import Tuple -from logging import getLogger from click import Choice -from click import confirm -from click import echo -from click import style -from click import group -from click import option -from click import pass_context from click import Path as ClickPath +from click import confirm, echo, group, option, pass_context, style from openfl.utilities.path_check import is_directory_traversal from openfl.utilities.workspace import dump_requirements_file @@ -27,7 +22,7 @@ @pass_context def workspace(context): """Manage Experimental Federated Learning Workspaces.""" - context.obj['group'] = 'workspace' + context.obj["group"] = "workspace" def create_dirs(prefix): @@ -36,30 +31,35 @@ def create_dirs(prefix): from openfl.experimental.interface.cli.cli_helper import WORKSPACE - echo('Creating Workspace Directories') + echo("Creating Workspace Directories") - (prefix / 'cert').mkdir(parents=True, exist_ok=True) # certifications - (prefix / 'data').mkdir(parents=True, exist_ok=True) # training data - (prefix / 'logs').mkdir(parents=True, exist_ok=True) # training logs - (prefix / 'save').mkdir(parents=True, exist_ok=True) # model weight saves / initialization - (prefix / 
'src').mkdir(parents=True, exist_ok=True) # model code + (prefix / "cert").mkdir(parents=True, exist_ok=True) # certifications + (prefix / "data").mkdir(parents=True, exist_ok=True) # training data + (prefix / "logs").mkdir(parents=True, exist_ok=True) # training logs + (prefix / "save").mkdir( + parents=True, exist_ok=True + ) # model weight saves / initialization + (prefix / "src").mkdir(parents=True, exist_ok=True) # model code - copyfile(WORKSPACE / 'workspace' / '.workspace', prefix / '.workspace') + copyfile(WORKSPACE / "workspace" / ".workspace", prefix / ".workspace") def create_temp(prefix, template): """Create workspace templates.""" from shutil import ignore_patterns - from openfl.experimental.interface.cli.cli_helper import copytree - from openfl.experimental.interface.cli.cli_helper import WORKSPACE + from openfl.experimental.interface.cli.cli_helper import WORKSPACE, copytree - echo('Creating Workspace Templates') + echo("Creating Workspace Templates") # Use the specified template if it's a Path, otherwise use WORKSPACE/template source = template if isinstance(template, Path) else WORKSPACE / template - copytree(src=source, dst=prefix, dirs_exist_ok=True, - ignore=ignore_patterns('__pycache__')) # from template workspace + copytree( + src=source, + dst=prefix, + dirs_exist_ok=True, + ignore=ignore_patterns("__pycache__"), + ) # from template workspace apply_template_plan(prefix, template) @@ -67,70 +67,86 @@ def get_templates(): """Grab the default templates from the distribution.""" from openfl.experimental.interface.cli.cli_helper import WORKSPACE - return [d.name for d in WORKSPACE.glob('*') if d.is_dir() - and d.name not in ['__pycache__', 'workspace']] - - -@workspace.command(name='create') -@option('--prefix', required=True, - help='Workspace name or path', type=ClickPath()) -@option('--custom_template', required=False, - help='Path to custom template', type=ClickPath(exists=True)) -@option('--notebook', required=False, - help='Path to jupyter notebook', type=ClickPath(exists=True)) -@option('--template_output_dir', required=False, - help='Destination directory to save your Jupyter Notebook workspace.', - type=ClickPath(exists=False, file_okay=False, dir_okay=True)) -@option('--template', required=False, type=Choice(get_templates())) + return [ + d.name + for d in WORKSPACE.glob("*") + if d.is_dir() and d.name not in ["__pycache__", "workspace"] + ] + + +@workspace.command(name="create") +@option( + "--prefix", required=True, help="Workspace name or path", type=ClickPath() +) +@option( + "--custom_template", + required=False, + help="Path to custom template", + type=ClickPath(exists=True), +) +@option( + "--notebook", + required=False, + help="Path to jupyter notebook", + type=ClickPath(exists=True), +) +@option( + "--template_output_dir", + required=False, + help="Destination directory to save your Jupyter Notebook workspace.", + type=ClickPath(exists=False, file_okay=False, dir_okay=True), +) +@option("--template", required=False, type=Choice(get_templates())) def create_(prefix, custom_template, template, notebook, template_output_dir): """Create the experimental workspace.""" if is_directory_traversal(prefix): - echo('Workspace name or path is out of the openfl workspace scope.') + echo("Workspace name or path is out of the openfl workspace scope.") sys.exit(1) if custom_template and template and notebook: raise ValueError( - 'Please provide either `template`, `custom_template` or ' - + '`notebook`. 
Not all are necessary' + "Please provide either `template`, `custom_template` or " + + "`notebook`. Not all are necessary" ) elif ( - (custom_template and template) - or (template and notebook) - or (custom_template and notebook)): + (custom_template and template) + or (template and notebook) + or (custom_template and notebook) + ): raise ValueError( - 'Please provide only one of the following options: ' - + '`template`, `custom_template`, or `notebook`.' + "Please provide only one of the following options: " + + "`template`, `custom_template`, or `notebook`." ) if not (custom_template or template or notebook): raise ValueError( - 'Please provide one of the following options: ' - + '`template`, `custom_template`, or `notebook`.' + "Please provide one of the following options: " + + "`template`, `custom_template`, or `notebook`." ) if notebook: if not template_output_dir: raise ValueError( - 'Please provide output_workspace which is Destination directory to ' - + 'save your Jupyter Notebook workspace.' + "Please provide output_workspace which is Destination directory to " + + "save your Jupyter Notebook workspace." ) from openfl.experimental.workspace_export import WorkspaceExport WorkspaceExport.export( - notebook_path=notebook, output_workspace=template_output_dir, + notebook_path=notebook, + output_workspace=template_output_dir, ) create(prefix, template_output_dir) logger.warning( - 'The user should review the generated workspace for completeness ' - + 'before proceeding') + "The user should review the generated workspace for completeness " + + "before proceeding" + ) else: template = ( - Path(custom_template).resolve() - if custom_template - else template + Path(custom_template).resolve() if custom_template else template ) create(prefix, template) @@ -143,7 +159,7 @@ def create(prefix, template): from openfl.experimental.interface.cli.cli_helper import ( OPENFL_USERDIR, - print_tree + print_tree, ) if not OPENFL_USERDIR.exists(): @@ -154,133 +170,166 @@ def create(prefix, template): create_dirs(prefix) create_temp(prefix, template) - requirements_filename = 'requirements.txt' + requirements_filename = "requirements.txt" - if not os.path.exists(f'{str(prefix)}/plan/data.yaml'): - echo(style('Participant private attributes shall be set to None as plan/data.yaml' - + ' was not found in the workspace.', fg='yellow')) + if not os.path.exists(f"{str(prefix)}/plan/data.yaml"): + echo( + style( + "Participant private attributes shall be set to None as plan/data.yaml" + + " was not found in the workspace.", + fg="yellow", + ) + ) - if isfile(f'{str(prefix)}/{requirements_filename}'): - check_call([ - executable, '-m', 'pip', 'install', '-r', - f'{prefix}/requirements.txt'], shell=False) - echo(f'Successfully installed packages from {prefix}/requirements.txt.') + if isfile(f"{str(prefix)}/{requirements_filename}"): + check_call( + [ + executable, + "-m", + "pip", + "install", + "-r", + f"{prefix}/requirements.txt", + ], + shell=False, + ) + echo(f"Successfully installed packages from {prefix}/requirements.txt.") else: - echo('No additional requirements for workspace defined. Skipping...') + echo("No additional requirements for workspace defined. 
Skipping...") prefix_hash = _get_dir_hash(str(prefix.absolute())) - with open(OPENFL_USERDIR / f'requirements.{prefix_hash}.txt', 'w', encoding='utf-8') as f: - check_call([executable, '-m', 'pip', 'freeze'], shell=False, stdout=f) + with open( + OPENFL_USERDIR / f"requirements.{prefix_hash}.txt", + "w", + encoding="utf-8", + ) as f: + check_call([executable, "-m", "pip", "freeze"], shell=False, stdout=f) print_tree(prefix, level=3) -@workspace.command(name='export') -@option('-o', '--pip-install-options', required=False, - type=str, multiple=True, default=tuple, - help='Options for remote pip install. ' - 'You may pass several options in quotation marks alongside with arguments, ' - 'e.g. -o "--find-links source.site"') +@workspace.command(name="export") +@option( + "-o", + "--pip-install-options", + required=False, + type=str, + multiple=True, + default=tuple, + help="Options for remote pip install. " + "You may pass several options in quotation marks alongside with arguments, " + 'e.g. -o "--find-links source.site"', +) def export_(pip_install_options: Tuple[str]): """Export federated learning workspace.""" - from os import getcwd - from os import makedirs - from os.path import basename - from os.path import join - from shutil import copy2 - from shutil import copytree - from shutil import ignore_patterns - from shutil import make_archive + from os import getcwd, makedirs + from os.path import basename, join + from shutil import copy2, copytree, ignore_patterns, make_archive from tempfile import mkdtemp from plan import freeze_plan + from openfl.experimental.interface.cli.cli_helper import WORKSPACE from openfl.utilities.utils import rmtree - echo(style('This command will archive the contents of \'plan\' and \'src\' directory, user' - + ' should review that these does not contain any information which is private and' - + ' not to be shared.', fg='yellow')) + echo( + style( + "This command will archive the contents of 'plan' and 'src' directory, user" + + " should review that these does not contain any information which is private and" + + " not to be shared.", + fg="yellow", + ) + ) - plan_file = Path('plan/plan.yaml').absolute() + plan_file = Path("plan/plan.yaml").absolute() try: freeze_plan(plan_file) except FileNotFoundError: echo(f'Plan file "{plan_file}" not found. No freeze performed.') # Dump requirements.txt - dump_requirements_file(prefixes=pip_install_options, keep_original_prefixes=True) + dump_requirements_file( + prefixes=pip_install_options, keep_original_prefixes=True + ) - archive_type = 'zip' + archive_type = "zip" archive_name = basename(getcwd()) - archive_file_name = archive_name + '.' + archive_type + archive_file_name = archive_name + "." 
+ archive_type
 # Aggregator workspace
- tmp_dir = join(mkdtemp(), 'openfl', archive_name)
+ tmp_dir = join(mkdtemp(), "openfl", archive_name)
 ignore = ignore_patterns(
- '__pycache__', '*.crt', '*.key', '*.csr', '*.srl', '*.pem', '*.pbuf')
+ "__pycache__", "*.crt", "*.key", "*.csr", "*.srl", "*.pem", "*.pbuf"
+ )
 # We only export the minimum required files to set up a collaborator
- makedirs(f'{tmp_dir}/save', exist_ok=True)
- makedirs(f'{tmp_dir}/logs', exist_ok=True)
- makedirs(f'{tmp_dir}/data', exist_ok=True)
- copytree('./src', f'{tmp_dir}/src', ignore=ignore) # code
- copytree('./plan', f'{tmp_dir}/plan', ignore=ignore) # plan
- copy2('./requirements.txt', f'{tmp_dir}/requirements.txt') # requirements
+ makedirs(f"{tmp_dir}/save", exist_ok=True)
+ makedirs(f"{tmp_dir}/logs", exist_ok=True)
+ makedirs(f"{tmp_dir}/data", exist_ok=True)
+ copytree("./src", f"{tmp_dir}/src", ignore=ignore) # code
+ copytree("./plan", f"{tmp_dir}/plan", ignore=ignore) # plan
+ copy2("./requirements.txt", f"{tmp_dir}/requirements.txt") # requirements
 try:
- copy2('.workspace', tmp_dir) # .workspace
+ copy2(".workspace", tmp_dir) # .workspace
 except FileNotFoundError:
- echo('\'.workspace\' file not found.')
- if confirm('Create a default \'.workspace\' file?'):
- copy2(WORKSPACE / 'workspace' / '.workspace', tmp_dir)
+ echo("'.workspace' file not found.")
+ if confirm("Create a default '.workspace' file?"):
+ copy2(WORKSPACE / "workspace" / ".workspace", tmp_dir)
 else:
- echo('To proceed, you must have a \'.workspace\' '
- 'file in the current directory.')
+ echo(
+ "To proceed, you must have a '.workspace' "
+ "file in the current directory."
+ )
 raise
 # Create Zip archive of directory
- echo('\n 🗜️ Preparing workspace distribution zip file')
+ echo("\n 🗜️ Preparing workspace distribution zip file")
 make_archive(archive_name, archive_type, tmp_dir)
 rmtree(tmp_dir)
- echo(f'\n ✔️ Workspace exported to archive: {archive_file_name}')
+ echo(f"\n ✔️ Workspace exported to archive: {archive_file_name}")
-@workspace.command(name='import')
-@option('--archive', required=True,
- help='Zip file containing workspace to import',
- type=ClickPath(exists=True))
+@workspace.command(name="import")
+@option(
+ "--archive",
+ required=True,
+ help="Zip file containing workspace to import",
+ type=ClickPath(exists=True),
+)
def import_(archive):
 """Import federated learning workspace."""
 from os import chdir
- from os.path import basename
- from os.path import isfile
+ from os.path import basename, isfile
 from shutil import unpack_archive
 from subprocess import check_call
 from sys import executable
 archive = Path(archive).absolute()
- dir_path = basename(archive).split('.')[0]
+ dir_path = basename(archive).split(".")[0]
 unpack_archive(archive, extract_dir=dir_path)
 chdir(dir_path)
- requirements_filename = 'requirements.txt'
+ requirements_filename = "requirements.txt"
 if isfile(requirements_filename):
- check_call([
- executable, '-m', 'pip', 'install', '--upgrade', 'pip'],
- shell=False)
- check_call([
- executable, '-m', 'pip', 'install', '-r', requirements_filename],
- shell=False)
+ check_call(
+ [executable, "-m", "pip", "install", "--upgrade", "pip"],
+ shell=False,
+ )
+ check_call(
+ [executable, "-m", "pip", "install", "-r", requirements_filename],
+ shell=False,
+ )
 else:
- echo('No ' + requirements_filename + ' file found.')
+ echo("No " + requirements_filename + " file found.")
- echo(f'Workspace {archive} has been imported.')
- echo('You may need to copy your PKI certificates to join the 
federation.') + echo(f"Workspace {archive} has been imported.") + echo("You may need to copy your PKI certificates to join the federation.") -@workspace.command(name='certify') +@workspace.command(name="certify") def certify_(): """Create certificate authority for federation.""" certify() @@ -290,122 +339,155 @@ def certify(): """Create certificate authority for federation.""" from cryptography.hazmat.primitives import serialization - from openfl.cryptography.ca import generate_root_cert - from openfl.cryptography.ca import generate_signing_csr - from openfl.cryptography.ca import sign_certificate + from openfl.cryptography.ca import ( + generate_root_cert, + generate_signing_csr, + sign_certificate, + ) from openfl.experimental.interface.cli.cli_helper import CERT_DIR - echo('Setting Up Certificate Authority...\n') + echo("Setting Up Certificate Authority...\n") - echo('1. Create Root CA') - echo('1.1 Create Directories') + echo("1. Create Root CA") + echo("1.1 Create Directories") - (CERT_DIR / 'ca/root-ca/private').mkdir( - parents=True, exist_ok=True, mode=0o700) - (CERT_DIR / 'ca/root-ca/db').mkdir(parents=True, exist_ok=True) + (CERT_DIR / "ca/root-ca/private").mkdir( + parents=True, exist_ok=True, mode=0o700 + ) + (CERT_DIR / "ca/root-ca/db").mkdir(parents=True, exist_ok=True) - echo('1.2 Create Database') + echo("1.2 Create Database") - with open(CERT_DIR / 'ca/root-ca/db/root-ca.db', 'w', encoding='utf-8') as f: + with open( + CERT_DIR / "ca/root-ca/db/root-ca.db", "w", encoding="utf-8" + ) as f: pass # write empty file - with open(CERT_DIR / 'ca/root-ca/db/root-ca.db.attr', 'w', encoding='utf-8') as f: + with open( + CERT_DIR / "ca/root-ca/db/root-ca.db.attr", "w", encoding="utf-8" + ) as f: pass # write empty file - with open(CERT_DIR / 'ca/root-ca/db/root-ca.crt.srl', 'w', encoding='utf-8') as f: - f.write('01') # write file with '01' - with open(CERT_DIR / 'ca/root-ca/db/root-ca.crl.srl', 'w', encoding='utf-8') as f: - f.write('01') # write file with '01' + with open( + CERT_DIR / "ca/root-ca/db/root-ca.crt.srl", "w", encoding="utf-8" + ) as f: + f.write("01") # write file with '01' + with open( + CERT_DIR / "ca/root-ca/db/root-ca.crl.srl", "w", encoding="utf-8" + ) as f: + f.write("01") # write file with '01' - echo('1.3 Create CA Request and Certificate') + echo("1.3 Create CA Request and Certificate") - root_crt_path = 'ca/root-ca.crt' - root_key_path = 'ca/root-ca/private/root-ca.key' + root_crt_path = "ca/root-ca.crt" + root_key_path = "ca/root-ca/private/root-ca.key" root_private_key, root_cert = generate_root_cert() # Write root CA certificate to disk - with open(CERT_DIR / root_crt_path, 'wb') as f: - f.write(root_cert.public_bytes( - encoding=serialization.Encoding.PEM, - )) + with open(CERT_DIR / root_crt_path, "wb") as f: + f.write( + root_cert.public_bytes( + encoding=serialization.Encoding.PEM, + ) + ) - with open(CERT_DIR / root_key_path, 'wb') as f: - f.write(root_private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() - )) + with open(CERT_DIR / root_key_path, "wb") as f: + f.write( + root_private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + ) - echo('2. Create Signing Certificate') - echo('2.1 Create Directories') + echo("2. 
Create Signing Certificate") + echo("2.1 Create Directories") - (CERT_DIR / 'ca/signing-ca/private').mkdir( - parents=True, exist_ok=True, mode=0o700) - (CERT_DIR / 'ca/signing-ca/db').mkdir(parents=True, exist_ok=True) + (CERT_DIR / "ca/signing-ca/private").mkdir( + parents=True, exist_ok=True, mode=0o700 + ) + (CERT_DIR / "ca/signing-ca/db").mkdir(parents=True, exist_ok=True) - echo('2.2 Create Database') + echo("2.2 Create Database") - with open(CERT_DIR / 'ca/signing-ca/db/signing-ca.db', 'w', encoding='utf-8') as f: + with open( + CERT_DIR / "ca/signing-ca/db/signing-ca.db", "w", encoding="utf-8" + ) as f: pass # write empty file - with open(CERT_DIR / 'ca/signing-ca/db/signing-ca.db.attr', 'w', encoding='utf-8') as f: + with open( + CERT_DIR / "ca/signing-ca/db/signing-ca.db.attr", "w", encoding="utf-8" + ) as f: pass # write empty file - with open(CERT_DIR / 'ca/signing-ca/db/signing-ca.crt.srl', 'w', encoding='utf-8') as f: - f.write('01') # write file with '01' - with open(CERT_DIR / 'ca/signing-ca/db/signing-ca.crl.srl', 'w', encoding='utf-8') as f: - f.write('01') # write file with '01' + with open( + CERT_DIR / "ca/signing-ca/db/signing-ca.crt.srl", "w", encoding="utf-8" + ) as f: + f.write("01") # write file with '01' + with open( + CERT_DIR / "ca/signing-ca/db/signing-ca.crl.srl", "w", encoding="utf-8" + ) as f: + f.write("01") # write file with '01' - echo('2.3 Create Signing Certificate CSR') + echo("2.3 Create Signing Certificate CSR") - signing_csr_path = 'ca/signing-ca.csr' - signing_crt_path = 'ca/signing-ca.crt' - signing_key_path = 'ca/signing-ca/private/signing-ca.key' + signing_csr_path = "ca/signing-ca.csr" + signing_crt_path = "ca/signing-ca.crt" + signing_key_path = "ca/signing-ca/private/signing-ca.key" signing_private_key, signing_csr = generate_signing_csr() # Write Signing CA CSR to disk - with open(CERT_DIR / signing_csr_path, 'wb') as f: - f.write(signing_csr.public_bytes( - encoding=serialization.Encoding.PEM, - )) + with open(CERT_DIR / signing_csr_path, "wb") as f: + f.write( + signing_csr.public_bytes( + encoding=serialization.Encoding.PEM, + ) + ) - with open(CERT_DIR / signing_key_path, 'wb') as f: - f.write(signing_private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() - )) + with open(CERT_DIR / signing_key_path, "wb") as f: + f.write( + signing_private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + ) - echo('2.4 Sign Signing Certificate CSR') + echo("2.4 Sign Signing Certificate CSR") - signing_cert = sign_certificate(signing_csr, root_private_key, root_cert.subject, ca=True) + signing_cert = sign_certificate( + signing_csr, root_private_key, root_cert.subject, ca=True + ) - with open(CERT_DIR / signing_crt_path, 'wb') as f: - f.write(signing_cert.public_bytes( - encoding=serialization.Encoding.PEM, - )) + with open(CERT_DIR / signing_crt_path, "wb") as f: + f.write( + signing_cert.public_bytes( + encoding=serialization.Encoding.PEM, + ) + ) - echo('3 Create Certificate Chain') + echo("3 Create Certificate Chain") # create certificate chain file by combining root-ca and signing-ca - with open(CERT_DIR / 'cert_chain.crt', 'w', encoding='utf-8') as d: - with open(CERT_DIR / 'ca/root-ca.crt', encoding='utf-8') as s: + with open(CERT_DIR / "cert_chain.crt", "w", encoding="utf-8") as d: + with open(CERT_DIR / 
"ca/root-ca.crt", encoding="utf-8") as s: d.write(s.read()) - with open(CERT_DIR / 'ca/signing-ca.crt') as s: + with open(CERT_DIR / "ca/signing-ca.crt") as s: d.write(s.read()) - echo('\nDone.') + echo("\nDone.") + # FIXME: Function is not in use def _get_requirements_dict(txtfile): - with open(txtfile, 'r', encoding='utf-8') as snapshot: + with open(txtfile, "r", encoding="utf-8") as snapshot: snapshot_dict = {} for line in snapshot: try: # 'pip freeze' generates requirements with exact versions - k, v = line.split('==') + k, v = line.split("==") snapshot_dict[k] = v except ValueError: snapshot_dict[line] = None @@ -414,8 +496,9 @@ def _get_requirements_dict(txtfile): def _get_dir_hash(path): from hashlib import sha256 + hash_ = sha256() - hash_.update(path.encode('utf-8')) + hash_.update(path.encode("utf-8")) hash_ = hash_.hexdigest() return hash_ @@ -432,6 +515,6 @@ def apply_template_plan(prefix, template): # Use the specified template if it's a Path, otherwise use WORKSPACE/template source = template if isinstance(template, Path) else WORKSPACE / template - template_plan = Plan.parse(source / 'plan' / 'plan.yaml') + template_plan = Plan.parse(source / "plan" / "plan.yaml") - Plan.dump(prefix / 'plan' / 'plan.yaml', template_plan.config) + Plan.dump(prefix / "plan" / "plan.yaml", template_plan.config) diff --git a/openfl/experimental/interface/fl_spec.py b/openfl/experimental/interface/fl_spec.py index 771e471d97..74ea7415af 100644 --- a/openfl/experimental/interface/fl_spec.py +++ b/openfl/experimental/interface/fl_spec.py @@ -1,24 +1,24 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.experimental.interface.flspec module.""" from __future__ import annotations import inspect from copy import deepcopy -from typing import Type, List, Callable +from typing import Callable, List, Type + +from openfl.experimental.runtime import Runtime from openfl.experimental.utilities import ( MetaflowInterface, SerializationError, - generate_artifacts, aggregator_to_collaborator, + checkpoint, collaborator_to_aggregator, - should_transfer, filter_attributes, - checkpoint + generate_artifacts, + should_transfer, ) -from openfl.experimental.runtime import Runtime class FLSpec: @@ -119,7 +119,9 @@ def _capture_instance_snapshot(self, kwargs): return_objs.append(backup) return return_objs - def _is_at_transition_point(self, f: Callable, parent_func: Callable) -> bool: + def _is_at_transition_point( + self, f: Callable, parent_func: Callable + ) -> bool: """ Has the collaborator finished its current sequence? 
@@ -130,12 +132,16 @@ def _is_at_transition_point(self, f: Callable, parent_func: Callable) -> bool: if parent_func.__name__ in self._foreach_methods: self._foreach_methods.append(f.__name__) if should_transfer(f, parent_func): - print(f"Should transfer from {parent_func.__name__} to {f.__name__}") + print( + f"Should transfer from {parent_func.__name__} to {f.__name__}" + ) self.execute_next = f.__name__ return True return False - def _display_transition_logs(self, f: Callable, parent_func: Callable) -> None: + def _display_transition_logs( + self, f: Callable, parent_func: Callable + ) -> None: """ Prints aggregator to collaborators or collaborators to aggregator state transition logs @@ -159,9 +165,9 @@ def filter_exclude_include(self, f, **kwargs): for col in selected_collaborators: clone = FLSpec._clones[col] clone.input = col - if ("exclude" in kwargs and hasattr(clone, kwargs["exclude"][0])) or ( - "include" in kwargs and hasattr(clone, kwargs["include"][0]) - ): + if ( + "exclude" in kwargs and hasattr(clone, kwargs["exclude"][0]) + ) or ("include" in kwargs and hasattr(clone, kwargs["include"][0])): filter_attributes(clone, f, **kwargs) artifacts_iter, _ = generate_artifacts(ctx=self) for name, attr in artifacts_iter(): @@ -184,7 +190,7 @@ def get_clones(self, kwargs): """ FLSpec._reset_clones() FLSpec._create_clones(self, self.runtime.collaborators) - selected_collaborators = self.__getattribute__(kwargs['foreach']) + selected_collaborators = self.__getattribute__(kwargs["foreach"]) for col in selected_collaborators: clone = FLSpec._clones[col] @@ -226,9 +232,15 @@ def next(self, f, **kwargs): if "foreach" in kwargs: self.filter_exclude_include(f, **kwargs) - # if "foreach" in kwargs: - self.execute_task_args = (self, f, parent_func, FLSpec._clones, - agg_to_collab_ss, kwargs) + # if "foreach" in kwargs: + self.execute_task_args = ( + self, + f, + parent_func, + FLSpec._clones, + agg_to_collab_ss, + kwargs, + ) else: self.execute_task_args = (self, f, parent_func, kwargs) diff --git a/openfl/experimental/interface/participants.py b/openfl/experimental/interface/participants.py index 84847fb6fd..4f48ae61fd 100644 --- a/openfl/experimental/interface/participants.py +++ b/openfl/experimental/interface/participants.py @@ -1,13 +1,12 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.experimental.interface.participants module.""" -from typing import Dict, Any -from typing import Callable, Optional +from typing import Any, Callable, Dict, Optional class Participant: + def __init__(self, name: str = ""): self.private_attributes = {} self._name = name.lower() @@ -98,7 +97,9 @@ def initialize_private_attributes(self) -> None: the callable specified by user """ if self.private_attributes_callable is not None: - self.private_attributes = self.private_attributes_callable(**self.kwargs) + self.private_attributes = self.private_attributes_callable( + **self.kwargs + ) def __set_collaborator_attrs_to_clone(self, clone: Any) -> None: """ @@ -119,7 +120,9 @@ def __delete_collab_attrs_from_clone(self, clone: Any) -> None: # parameters from clone, then delete attributes from clone. 
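 # Note that the value is copied back from the clone *before* delattr():
 # if the step mutated this private attribute, the updated value (not a
 # stale copy) is what gets restored on the next round.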
for attr_name in self.private_attributes: if hasattr(clone, attr_name): - self.private_attributes.update({attr_name: getattr(clone, attr_name)}) + self.private_attributes.update( + {attr_name: getattr(clone, attr_name)} + ) delattr(clone, attr_name) def execute_func(self, ctx: Any, f_name: str, callback: Callable) -> Any: @@ -194,7 +197,9 @@ def initialize_private_attributes(self) -> None: the callable specified by user """ if self.private_attributes_callable is not None: - self.private_attributes = self.private_attributes_callable(**self.kwargs) + self.private_attributes = self.private_attributes_callable( + **self.kwargs + ) def __set_agg_attrs_to_clone(self, clone: Any) -> None: """ @@ -215,11 +220,18 @@ def __delete_agg_attrs_from_clone(self, clone: Any) -> None: # parameters from clone, then delete attributes from clone. for attr_name in self.private_attributes: if hasattr(clone, attr_name): - self.private_attributes.update({attr_name: getattr(clone, attr_name)}) + self.private_attributes.update( + {attr_name: getattr(clone, attr_name)} + ) delattr(clone, attr_name) - def execute_func(self, ctx: Any, f_name: str, callback: Callable, - clones: Optional[Any] = None) -> Any: + def execute_func( + self, + ctx: Any, + f_name: str, + callback: Callable, + clones: Optional[Any] = None, + ) -> Any: """ Execute remote function f """ diff --git a/openfl/experimental/placement/__init__.py b/openfl/experimental/placement/__init__.py index 05b12d50bb..b0c05b1b1b 100644 --- a/openfl/experimental/placement/__init__.py +++ b/openfl/experimental/placement/__init__.py @@ -1,8 +1,6 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.experimental.placement package.""" -from .placement import aggregator, collaborator - -__all__ = ["aggregator", "collaborator"] +# FIXME: Unnecessary recursion. 
+from openfl.experimental.placement.placement import aggregator, collaborator diff --git a/openfl/experimental/placement/placement.py b/openfl/experimental/placement/placement.py index a66b47f72c..f7ba1f16e2 100644 --- a/openfl/experimental/placement/placement.py +++ b/openfl/experimental/placement/placement.py @@ -2,9 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 import functools -from openfl.experimental.utilities import RedirectStdStreamContext from typing import Callable +from openfl.experimental.utilities import RedirectStdStreamContext + def aggregator(f: Callable = None) -> Callable: """ diff --git a/openfl/experimental/protocols/interceptors.py b/openfl/experimental/protocols/interceptors.py index a54ff76d82..02f9c1b6d1 100644 --- a/openfl/experimental/protocols/interceptors.py +++ b/openfl/experimental/protocols/interceptors.py @@ -1,44 +1,52 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """gRPC interceptors module.""" import collections import grpc -class _GenericClientInterceptor(grpc.UnaryUnaryClientInterceptor, - grpc.UnaryStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor, - grpc.StreamStreamClientInterceptor): +class _GenericClientInterceptor( + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor, +): def __init__(self, interceptor_function): self._fn = interceptor_function def intercept_unary_unary(self, continuation, client_call_details, request): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, iter((request,)), False, False) + client_call_details, iter((request,)), False, False + ) response = continuation(new_details, next(new_request_iterator)) return postprocess(response) if postprocess else response - def intercept_unary_stream(self, continuation, client_call_details, - request): + def intercept_unary_stream( + self, continuation, client_call_details, request + ): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, iter((request,)), False, True) + client_call_details, iter((request,)), False, True + ) response_it = continuation(new_details, next(new_request_iterator)) return postprocess(response_it) if postprocess else response_it - def intercept_stream_unary(self, continuation, client_call_details, - request_iterator): + def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, request_iterator, True, False) + client_call_details, request_iterator, True, False + ) response = continuation(new_details, new_request_iterator) return postprocess(response) if postprocess else response - def intercept_stream_stream(self, continuation, client_call_details, - request_iterator): + def intercept_stream_stream( + self, continuation, client_call_details, request_iterator + ): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, request_iterator, True, True) + client_call_details, request_iterator, True, True + ) response_it = continuation(new_details, new_request_iterator) return postprocess(response_it) if postprocess else response_it @@ -49,10 +57,9 @@ def _create_generic_interceptor(intercept_call): class _ClientCallDetails( collections.namedtuple( - '_ClientCallDetails', - ('method', 'timeout', 'metadata', 'credentials') + "_ClientCallDetails", ("method", "timeout", "metadata", "credentials") ), - grpc.ClientCallDetails + 
grpc.ClientCallDetails, ): pass @@ -60,19 +67,28 @@ class _ClientCallDetails( def headers_adder(headers): """Create interceptor with added headers.""" - def intercept_call(client_call_details, request_iterator, request_streaming, - response_streaming): + def intercept_call( + client_call_details, + request_iterator, + request_streaming, + response_streaming, + ): metadata = [] if client_call_details.metadata is not None: metadata = list(client_call_details.metadata) for header, value in headers.items(): - metadata.append(( - header, - value, - )) + metadata.append( + ( + header, + value, + ) + ) client_call_details = _ClientCallDetails( - client_call_details.method, client_call_details.timeout, metadata, - client_call_details.credentials) + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + ) return client_call_details, request_iterator, None return _create_generic_interceptor(intercept_call) diff --git a/openfl/experimental/protocols/utils.py b/openfl/experimental/protocols/utils.py index fc6edc7bae..5d2c5c8df5 100644 --- a/openfl/experimental/protocols/utils.py +++ b/openfl/experimental/protocols/utils.py @@ -21,55 +21,65 @@ def model_proto_to_bytes_and_metadata(model_proto): round_number = None for tensor_proto in model_proto.tensors: bytes_dict[tensor_proto.name] = tensor_proto.data_bytes - metadata_dict[tensor_proto.name] = [{ - 'int_to_float': proto.int_to_float, - 'int_list': proto.int_list, - 'bool_list': proto.bool_list - } + metadata_dict[tensor_proto.name] = [ + { + "int_to_float": proto.int_to_float, + "int_list": proto.int_list, + "bool_list": proto.bool_list, + } for proto in tensor_proto.transformer_metadata ] if round_number is None: round_number = tensor_proto.round_number else: assert round_number == tensor_proto.round_number, ( - f'Round numbers in model are inconsistent: {round_number} ' - f'and {tensor_proto.round_number}' + f"Round numbers in model are inconsistent: {round_number} " + f"and {tensor_proto.round_number}" ) return bytes_dict, metadata_dict, round_number -def bytes_and_metadata_to_model_proto(bytes_dict, model_id, model_version, - is_delta, metadata_dict): +def bytes_and_metadata_to_model_proto( + bytes_dict, model_id, model_version, is_delta, metadata_dict +): """Convert bytes and metadata to model protobuf.""" - model_header = ModelHeader(id=model_id, version=model_version, is_delta=is_delta) # NOQA:F821 + model_header = ModelHeader( + id=model_id, version=model_version, is_delta=is_delta + ) # NOQA:F821 tensor_protos = [] for key, data_bytes in bytes_dict.items(): transformer_metadata = metadata_dict[key] metadata_protos = [] for metadata in transformer_metadata: - if metadata.get('int_to_float') is not None: - int_to_float = metadata.get('int_to_float') + if metadata.get("int_to_float") is not None: + int_to_float = metadata.get("int_to_float") else: int_to_float = {} - if metadata.get('int_list') is not None: - int_list = metadata.get('int_list') + if metadata.get("int_list") is not None: + int_list = metadata.get("int_list") else: int_list = [] - if metadata.get('bool_list') is not None: - bool_list = metadata.get('bool_list') + if metadata.get("bool_list") is not None: + bool_list = metadata.get("bool_list") else: bool_list = [] - metadata_protos.append(base_pb2.MetadataProto( - int_to_float=int_to_float, - int_list=int_list, - bool_list=bool_list, - )) - tensor_protos.append(TensorProto(name=key, # NOQA:F821 - data_bytes=data_bytes, - transformer_metadata=metadata_protos)) + 
metadata_protos.append( + base_pb2.MetadataProto( + int_to_float=int_to_float, + int_list=int_list, + bool_list=bool_list, + ) + ) + tensor_protos.append( + TensorProto( + name=key, # NOQA:F821 + data_bytes=data_bytes, + transformer_metadata=metadata_protos, + ) + ) return base_pb2.ModelProto(header=model_header, tensors=tensor_protos) @@ -77,25 +87,27 @@ def construct_named_tensor(tensor_key, nparray, transformer_metadata, lossless): """Construct named tensor.""" metadata_protos = [] for metadata in transformer_metadata: - if metadata.get('int_to_float') is not None: - int_to_float = metadata.get('int_to_float') + if metadata.get("int_to_float") is not None: + int_to_float = metadata.get("int_to_float") else: int_to_float = {} - if metadata.get('int_list') is not None: - int_list = metadata.get('int_list') + if metadata.get("int_list") is not None: + int_list = metadata.get("int_list") else: int_list = [] - if metadata.get('bool_list') is not None: - bool_list = metadata.get('bool_list') + if metadata.get("bool_list") is not None: + bool_list = metadata.get("bool_list") else: bool_list = [] - metadata_protos.append(base_pb2.MetadataProto( - int_to_float=int_to_float, - int_list=int_list, - bool_list=bool_list, - )) + metadata_protos.append( + base_pb2.MetadataProto( + int_to_float=int_to_float, + int_list=int_list, + bool_list=bool_list, + ) + ) tensor_name, origin, round_number, report, tags = tensor_key @@ -110,21 +122,27 @@ def construct_named_tensor(tensor_key, nparray, transformer_metadata, lossless): ) -def construct_proto(tensor_dict, model_id, model_version, is_delta, compression_pipeline): +def construct_proto( + tensor_dict, model_id, model_version, is_delta, compression_pipeline +): """Construct proto.""" # compress the arrays in the tensor_dict, and form the model proto # TODO: Hold-out tensors from the compression pipeline. 
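 # forward() returns, for each tensor, the compressed bytes plus the
 # transformer metadata needed later to invert the compression.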
bytes_dict = {} metadata_dict = {} for key, array in tensor_dict.items(): - bytes_dict[key], metadata_dict[key] = compression_pipeline.forward(data=array) + bytes_dict[key], metadata_dict[key] = compression_pipeline.forward( + data=array + ) # convert the compressed_tensor_dict and metadata to protobuf, and make the new model proto - model_proto = bytes_and_metadata_to_model_proto(bytes_dict=bytes_dict, - model_id=model_id, - model_version=model_version, - is_delta=is_delta, - metadata_dict=metadata_dict) + model_proto = bytes_and_metadata_to_model_proto( + bytes_dict=bytes_dict, + model_id=model_id, + model_version=model_version, + is_delta=is_delta, + metadata_dict=metadata_dict, + ) return model_proto @@ -135,13 +153,15 @@ def construct_model_proto(tensor_dict, round_number, tensor_pipe): named_tensors = [] for key, nparray in tensor_dict.items(): bytes_data, transformer_metadata = tensor_pipe.forward(data=nparray) - tensor_key = TensorKey(key, 'agg', round_number, False, ('model',)) - named_tensors.append(construct_named_tensor( - tensor_key, - bytes_data, - transformer_metadata, - lossless=True, - )) + tensor_key = TensorKey(key, "agg", round_number, False, ("model",)) + named_tensors.append( + construct_named_tensor( + tensor_key, + bytes_data, + transformer_metadata, + lossless=True, + ) + ) return base_pb2.ModelProto(tensors=named_tensors) @@ -149,15 +169,18 @@ def construct_model_proto(tensor_dict, round_number, tensor_pipe): def deconstruct_model_proto(model_proto, compression_pipeline): """Deconstruct model proto.""" # extract the tensor_dict and metadata - bytes_dict, metadata_dict, round_number = model_proto_to_bytes_and_metadata(model_proto) + bytes_dict, metadata_dict, round_number = model_proto_to_bytes_and_metadata( + model_proto + ) # decompress the tensors # TODO: Handle tensors meant to be held-out from the compression pipeline # (currently none are held out). tensor_dict = {} for key in bytes_dict: - tensor_dict[key] = compression_pipeline.backward(data=bytes_dict[key], - transformer_metadata=metadata_dict[key]) + tensor_dict[key] = compression_pipeline.backward( + data=bytes_dict[key], transformer_metadata=metadata_dict[key] + ) return tensor_dict, round_number @@ -179,8 +202,9 @@ def deconstruct_proto(model_proto, compression_pipeline): # (currently none are held out). 
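 # backward() must receive the same transformer metadata that forward()
 # emitted at compression time, or the original array cannot be rebuilt.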
tensor_dict = {} for key in bytes_dict: - tensor_dict[key] = compression_pipeline.backward(data=bytes_dict[key], - transformer_metadata=metadata_dict[key]) + tensor_dict[key] = compression_pipeline.backward( + data=bytes_dict[key], transformer_metadata=metadata_dict[key] + ) return tensor_dict @@ -193,7 +217,7 @@ def load_proto(fpath): Returns: protobuf: A protobuf of the model """ - with open(fpath, 'rb') as f: + with open(fpath, "rb") as f: loaded = f.read() model = base_pb2.ModelProto().FromString(loaded) return model @@ -208,7 +232,7 @@ def dump_proto(model_proto, fpath): """ s = model_proto.SerializeToString() - with open(fpath, 'wb') as f: + with open(fpath, "wb") as f: f.write(s) @@ -223,17 +247,19 @@ def datastream_to_proto(proto, stream, logger=None): Returns: protobuf: A protobuf of the model """ - npbytes = b'' + npbytes = b"" for chunk in stream: npbytes += chunk.npbytes if len(npbytes) > 0: proto.ParseFromString(npbytes) if logger is not None: - logger.debug(f'datastream_to_proto parsed a {type(proto)}.') + logger.debug(f"datastream_to_proto parsed a {type(proto)}.") return proto else: - raise RuntimeError(f'Received empty stream message of type {type(proto)}') + raise RuntimeError( + f"Received empty stream message of type {type(proto)}" + ) def proto_to_datastream(proto, logger, max_buffer_size=(2 * 1024 * 1024)): @@ -249,10 +275,12 @@ def proto_to_datastream(proto, logger, max_buffer_size=(2 * 1024 * 1024)): npbytes = proto.SerializeToString() data_size = len(npbytes) buffer_size = data_size if max_buffer_size > data_size else max_buffer_size - logger.debug(f'Setting stream chunks with size {buffer_size} for proto of type {type(proto)}') + logger.debug( + f"Setting stream chunks with size {buffer_size} for proto of type {type(proto)}" + ) for i in range(0, data_size, buffer_size): - chunk = npbytes[i: i + buffer_size] + chunk = npbytes[i : i + buffer_size] reply = base_pb2.DataStream(npbytes=chunk, size=len(chunk)) yield reply diff --git a/openfl/experimental/runtime/__init__.py b/openfl/experimental/runtime/__init__.py index 488e4b53bb..195b42fe5d 100644 --- a/openfl/experimental/runtime/__init__.py +++ b/openfl/experimental/runtime/__init__.py @@ -1,11 +1,7 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """ openfl.experimental.runtime package Runtime class.""" -from .runtime import Runtime -from .local_runtime import LocalRuntime -from .federated_runtime import FederatedRuntime - - -__all__ = ["FederatedRuntime", "LocalRuntime", "Runtime"] +from openfl.experimental.runtime.federated_runtime import FederatedRuntime +from openfl.experimental.runtime.local_runtime import LocalRuntime +from openfl.experimental.runtime.runtime import Runtime diff --git a/openfl/experimental/runtime/federated_runtime.py b/openfl/experimental/runtime/federated_runtime.py index da7ef3efb2..bae51c8fe3 100644 --- a/openfl/experimental/runtime/federated_runtime.py +++ b/openfl/experimental/runtime/federated_runtime.py @@ -1,21 +1,22 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """ openfl.experimental.runtime package LocalRuntime class.""" from __future__ import annotations -from openfl.experimental.runtime import Runtime + from typing import TYPE_CHECKING +from openfl.experimental.runtime.runtime import Runtime + if TYPE_CHECKING: from openfl.experimental.interface import Aggregator from openfl.experimental.interface import Collaborator -from typing import List -from typing import Type +from typing import List, Type 
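One pattern this refactor leans on throughout the runtime modules: imports needed only for type hints are guarded behind `typing.TYPE_CHECKING`, which breaks the circular dependency between the runtime and interface packages at run time. A generic, self-contained illustration of the pattern (the `interface` module name below is invented for the example):

```python
from __future__ import annotations  # hints are stored as strings, lazily

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers (mypy, pyright); skipped at
    # runtime, so the interface <-> runtime import cycle never materializes.
    from interface import Aggregator


class Runtime:
    def register(self, aggregator: Aggregator) -> None:
        # At runtime the annotation is never resolved, so the missing
        # module causes no ImportError.
        print(f"registered {aggregator!r}")


Runtime().register(aggregator="anything")  # a type checker would flag this
```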
class FederatedRuntime(Runtime): + def __init__( self, aggregator: str = None, diff --git a/openfl/experimental/runtime/local_runtime.py b/openfl/experimental/runtime/local_runtime.py index 4bf46bc141..19f8de0041 100644 --- a/openfl/experimental/runtime/local_runtime.py +++ b/openfl/experimental/runtime/local_runtime.py @@ -1,35 +1,38 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """ openfl.experimental.runtime package LocalRuntime class.""" from __future__ import annotations -from copy import deepcopy + +import gc import importlib -import ray +import math import os -import gc -from openfl.experimental.runtime import Runtime +from copy import deepcopy from typing import TYPE_CHECKING, Optional -import math + +import ray + +from openfl.experimental.runtime.runtime import Runtime if TYPE_CHECKING: from openfl.experimental.interface import Aggregator, Collaborator, FLSpec +from typing import Any, Callable, Dict, List, Type + from openfl.experimental.utilities import ( ResourcesNotAvailableError, aggregator_to_collaborator, - generate_artifacts, - filter_attributes, + check_resource_allocation, checkpoint, + filter_attributes, + generate_artifacts, get_number_of_gpus, - check_resource_allocation, ) -from typing import List, Any -from typing import Dict, Type, Callable class RayExecutor: + def __init__(self): """Create RayExecutor object""" self.__remote_contexts = [] @@ -103,14 +106,18 @@ def __init__(self, collaborator_actor, collaborator): for method in dir(Collaborator) if callable(getattr(Collaborator, method)) ] - external_methods = [method for method in all_methods if (method[0] != "_")] + external_methods = [ + method for method in all_methods if (method[0] != "_") + ] self.collaborator_actor = collaborator_actor self.collaborator = collaborator for method in external_methods: setattr( self, method, - RemoteHelper(self.collaborator_actor, self.collaborator, method), + RemoteHelper( + self.collaborator_actor, self.collaborator, method + ), ) class RemoteHelper: @@ -138,10 +145,8 @@ def __init__(self, collaborator_actor, collaborator, f_name) -> None: self.f_name = f_name self.collaborator_actor = collaborator_actor self.collaborator = collaborator - self.f = ( - lambda *args, **kwargs: self.collaborator_actor.execute_from_col.remote( - self.collaborator, self.f_name, *args, **kwargs - ) + self.f = lambda *args, **kwargs: self.collaborator_actor.execute_from_col.remote( + self.collaborator, self.f_name, *args, **kwargs ) def remote(self, *args, **kwargs): @@ -178,7 +183,7 @@ def remote(self, *args, **kwargs): [ i.num_cpus for i in collaborators_sorted_by_gpucpu[ - times_called: times_called + collaborators_per_group + times_called : times_called + collaborators_per_group ] ] ) @@ -186,7 +191,7 @@ def remote(self, *args, **kwargs): [ i.num_gpus for i in collaborators_sorted_by_gpucpu[ - times_called: times_called + collaborators_per_group + times_called : times_called + collaborators_per_group ] ] ) @@ -287,6 +292,7 @@ def get_collaborator(self, name): class LocalRuntime(Runtime): + def __init__( self, aggregator: Dict = None, @@ -378,7 +384,9 @@ def __get_aggregator_object(self, aggregator: Type[Aggregator]) -> Any: ({agg_cpus} < {total_available_cpus})." 
) - interface_module = importlib.import_module("openfl.experimental.interface") + interface_module = importlib.import_module( + "openfl.experimental.interface" + ) aggregator_class = getattr(interface_module, "Aggregator") aggregator_actor = ray.remote(aggregator_class).options( @@ -435,10 +443,12 @@ def collaborators(self) -> List[str]: def collaborators(self, collaborators: List[Type[Collaborator]]): """Set LocalRuntime collaborators""" if self.backend == "single_process": + def get_collab_name(collab): return collab.get_name() else: + def get_collab_name(collab): return ray.get(collab.get_name.remote()) @@ -463,7 +473,9 @@ def get_collaborator_kwargs(self, collaborator_name: str): if hasattr(collab, "private_attributes_callable"): if collab.private_attributes_callable is not None: kwargs.update(collab.kwargs) - kwargs["private_attributes_callable"] = collab.private_attributes_callable.__name__ + kwargs["private_attributes_callable"] = ( + collab.private_attributes_callable.__name__ + ) return kwargs @@ -499,7 +511,9 @@ def restore_instance_snapshot( if not hasattr(ctx, name): setattr(ctx, name, attr) - def execute_agg_steps(self, ctx: Any, f_name: str, clones: Optional[Any] = None): + def execute_agg_steps( + self, ctx: Any, f_name: str, clones: Optional[Any] = None + ): """ Execute aggregator steps until at transition point """ @@ -513,7 +527,10 @@ def execute_agg_steps(self, ctx: Any, f_name: str, clones: Optional[Any] = None) f() f, parent_func = ctx.execute_task_args[:2] - if aggregator_to_collaborator(f, parent_func) or f.__name__ == "end": + if ( + aggregator_to_collaborator(f, parent_func) + or f.__name__ == "end" + ): not_at_transition_point = False f_name = f.__name__ @@ -557,7 +574,9 @@ def execute_task(self, flspec_obj: Type[FLSpec], f: Callable, **kwargs): ) else: flspec_obj = self.execute_agg_task(flspec_obj, f) - f, parent_func, instance_snapshot, kwargs = flspec_obj.execute_task_args + f, parent_func, instance_snapshot, kwargs = ( + flspec_obj.execute_task_args + ) else: flspec_obj = self.execute_agg_task(flspec_obj, f) f = flspec_obj.execute_task_args[0] @@ -577,17 +596,24 @@ def execute_agg_task(self, flspec_obj, f): flspec_obj: updated FLSpec (flow) object """ from openfl.experimental.interface import FLSpec + aggregator = self._aggregator clones = None if self.join_step: - clones = [FLSpec._clones[col] for col in self.selected_collaborators] + clones = [ + FLSpec._clones[col] for col in self.selected_collaborators + ] self.join_step = False if self.backend == "ray": ray_executor = RayExecutor() ray_executor.ray_call_put( - aggregator, flspec_obj, f.__name__, self.execute_agg_steps, clones + aggregator, + flspec_obj, + f.__name__, + self.execute_agg_steps, + clones, ) flspec_obj = ray_executor.ray_call_get()[0] del ray_executor @@ -620,16 +646,16 @@ def execute_collab_task( flspec_obj: updated FLSpec (flow) object """ - from openfl.experimental.interface import ( - FLSpec, - ) + from openfl.experimental.interface import FLSpec flspec_obj._foreach_methods.append(f.__name__) selected_collaborators = getattr(flspec_obj, kwargs["foreach"]) self.selected_collaborators = selected_collaborators # filter exclude/include attributes for clone - self.filter_exclude_include(flspec_obj, f, selected_collaborators, **kwargs) + self.filter_exclude_include( + flspec_obj, f, selected_collaborators, **kwargs + ) if self.backend == "ray": ray_executor = RayExecutor() @@ -654,7 +680,9 @@ def execute_collab_task( collaborator, clone, f.__name__, self.execute_collab_steps ) else: - 
collaborator.execute_func(clone, f.__name__, self.execute_collab_steps) + collaborator.execute_func( + clone, f.__name__, self.execute_collab_steps + ) if self.backend == "ray": clones = ray_executor.ray_call_get() @@ -673,7 +701,9 @@ def execute_collab_task( self.join_step = True return flspec_obj - def filter_exclude_include(self, flspec_obj, f, selected_collaborators, **kwargs): + def filter_exclude_include( + self, flspec_obj, f, selected_collaborators, **kwargs + ): """ This function filters exclude/include attributes Args: @@ -682,16 +712,14 @@ def filter_exclude_include(self, flspec_obj, f, selected_collaborators, **kwargs selected_collaborators : all collaborators """ - from openfl.experimental.interface import ( - FLSpec, - ) + from openfl.experimental.interface import FLSpec for col in selected_collaborators: clone = FLSpec._clones[col] clone.input = col - if ("exclude" in kwargs and hasattr(clone, kwargs["exclude"][0])) or ( - "include" in kwargs and hasattr(clone, kwargs["include"][0]) - ): + if ( + "exclude" in kwargs and hasattr(clone, kwargs["exclude"][0]) + ) or ("include" in kwargs and hasattr(clone, kwargs["include"][0])): filter_attributes(clone, f, **kwargs) artifacts_iter, _ = generate_artifacts(ctx=flspec_obj) for name, attr in artifacts_iter(): diff --git a/openfl/experimental/runtime/runtime.py b/openfl/experimental/runtime/runtime.py index 3b769b5995..a9e5a5d9e3 100644 --- a/openfl/experimental/runtime/runtime.py +++ b/openfl/experimental/runtime/runtime.py @@ -1,16 +1,18 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """ openfl.experimental.runtime module Runtime class.""" from __future__ import annotations + from typing import TYPE_CHECKING + if TYPE_CHECKING: from openfl.experimental.interface import Aggregator, Collaborator, FLSpec -from typing import List -from typing import Callable + +from typing import Callable, List class Runtime: + def __init__(self): """ Base interface for runtimes that can run FLSpec flows @@ -46,7 +48,7 @@ def execute_task( f: Callable, parent_func: Callable, instance_snapshot: List[FLSpec] = [], - **kwargs + **kwargs, ): """ Performs the execution of a task as defined by the diff --git a/openfl/experimental/transport/__init__.py b/openfl/experimental/transport/__init__.py index 5b20dba61b..37a10d93f9 100644 --- a/openfl/experimental/transport/__init__.py +++ b/openfl/experimental/transport/__init__.py @@ -1,12 +1,7 @@ # Copyright (C) 2020-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.experimental.transport package.""" -from .grpc import AggregatorGRPCClient -from .grpc import AggregatorGRPCServer - - -__all__ = [ - 'AggregatorGRPCServer', - 'AggregatorGRPCClient', -] +from openfl.experimental.transport.grpc import ( + AggregatorGRPCClient, + AggregatorGRPCServer, +) diff --git a/openfl/experimental/transport/grpc/__init__.py b/openfl/experimental/transport/grpc/__init__.py index 270fc493c7..2b66ade490 100644 --- a/openfl/experimental/transport/grpc/__init__.py +++ b/openfl/experimental/transport/grpc/__init__.py @@ -1,18 +1,15 @@ # Copyright (C) 2020-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.experimental.transport.grpc package.""" -from .aggregator_client import AggregatorGRPCClient -from .aggregator_server import AggregatorGRPCServer +from openfl.experimental.transport.grpc.aggregator_client import ( + AggregatorGRPCClient, +) +from openfl.experimental.transport.grpc.aggregator_server import ( + AggregatorGRPCServer, +) +# FIXME: Not the 
right place for exceptions class ShardNotFoundError(Exception): """Indicates that director has no information about that shard.""" - - -__all__ = [ - 'AggregatorGRPCServer', - 'AggregatorGRPCClient', - 'ShardNotFoundError', -] diff --git a/openfl/experimental/transport/grpc/aggregator_client.py b/openfl/experimental/transport/grpc/aggregator_client.py index 3982a7031a..ba04b7d629 100644 --- a/openfl/experimental/transport/grpc/aggregator_client.py +++ b/openfl/experimental/transport/grpc/aggregator_client.py @@ -1,21 +1,19 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """AggregatorGRPCClient module.""" import time from logging import getLogger -from typing import Optional -from typing import Tuple +from typing import Optional, Tuple import grpc -from openfl.experimental.protocols import aggregator_pb2 -from openfl.experimental.protocols import aggregator_pb2_grpc +from openfl.experimental.protocols import aggregator_pb2, aggregator_pb2_grpc +from openfl.experimental.transport.grpc.grpc_channel_options import ( + channel_options, +) from openfl.utilities import check_equal -from .grpc_channel_options import channel_options - class ConstantBackoff: """Constant Backoff policy.""" @@ -28,7 +26,7 @@ def __init__(self, reconnect_interval, logger, uri): def sleep(self): """Sleep for specified interval.""" - self.logger.info(f'Attempting to connect to aggregator at {self.uri}') + self.logger.info(f"Attempting to connect to aggregator at {self.uri}") time.sleep(self.reconnect_interval) @@ -38,15 +36,17 @@ class RetryOnRpcErrorClientInterceptor( """Retry gRPC connection on failure.""" def __init__( - self, - sleeping_policy, - status_for_retry: Optional[Tuple[grpc.StatusCode]] = None, + self, + sleeping_policy, + status_for_retry: Optional[Tuple[grpc.StatusCode]] = None, ): """Initialize function for gRPC retry.""" self.sleeping_policy = sleeping_policy self.status_for_retry = status_for_retry - def _intercept_call(self, continuation, client_call_details, request_or_iterator): + def _intercept_call( + self, continuation, client_call_details, request_or_iterator + ): """Intercept the call to the gRPC server.""" while True: response = continuation(client_call_details, request_or_iterator) @@ -54,10 +54,12 @@ def _intercept_call(self, continuation, client_call_details, request_or_iterator if isinstance(response, grpc.RpcError): # If status code is not in retryable status codes - self.sleeping_policy.logger.info(f'Response code: {response.code()}') + self.sleeping_policy.logger.info( + f"Response code: {response.code()}" + ) if ( - self.status_for_retry - and response.code() not in self.status_for_retry + self.status_for_retry + and response.code() not in self.status_for_retry ): return response @@ -70,13 +72,16 @@ def intercept_unary_unary(self, continuation, client_call_details, request): return self._intercept_call(continuation, client_call_details, request) def intercept_stream_unary( - self, continuation, client_call_details, request_iterator + self, continuation, client_call_details, request_iterator ): """Wrap intercept call for stream->unary RPC.""" - return self._intercept_call(continuation, client_call_details, request_iterator) + return self._intercept_call( + continuation, client_call_details, request_iterator + ) def _atomic_connection(func): + def wrapper(self, *args, **kwargs): self.reconnect() response = func(self, *args, **kwargs)
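`_atomic_connection` above and `_resend_data_on_reconnection` below wrap every client RPC so that each call opens a fresh channel and is transparently retried after a dropped connection. A minimal, generic sketch of the same decorator pattern, for illustration only (the `try/finally` guard is an addition of this sketch, not part of the patch):

```python
import functools


def atomic_connection(func):
    """Open a fresh channel around each RPC, as _atomic_connection does."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        self.reconnect()  # open a new channel before the call
        try:
            return func(self, *args, **kwargs)
        finally:
            self.disconnect()  # close it again, even if the call raises

    return wrapper
```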
@@ -87,6 +92,7 @@ def wrapper(self, *args, **kwargs): def _resend_data_on_reconnection(func): + def wrapper(self, *args, **kwargs): while True: try: @@ -94,7 +100,7 @@ def wrapper(self, *args, **kwargs): except grpc.RpcError as e: if e.code() == grpc.StatusCode.UNKNOWN: self.logger.info( - f'Attempting to resend data request to aggregator at {self.uri}' + f"Attempting to resend data request to aggregator at {self.uri}" ) elif e.code() == grpc.StatusCode.UNAUTHENTICATED: raise @@ -108,20 +114,22 @@ def wrapper(self, *args, **kwargs): class AggregatorGRPCClient: """Client to the aggregator over gRPC-TLS.""" - def __init__(self, - agg_addr, - agg_port, - tls, - disable_client_auth, - root_certificate, - certificate, - private_key, - aggregator_uuid=None, - federation_uuid=None, - single_col_cert_common_name=None, - **kwargs): + def __init__( + self, + agg_addr, + agg_port, + tls, + disable_client_auth, + root_certificate, + certificate, + private_key, + aggregator_uuid=None, + federation_uuid=None, + single_col_cert_common_name=None, + **kwargs, + ): """Initialize.""" - self.uri = f'{agg_addr}:{agg_port}' + self.uri = f"{agg_addr}:{agg_port}" self.tls = tls self.disable_client_auth = disable_client_auth self.root_certificate = root_certificate @@ -132,7 +140,8 @@ def __init__(self, if not self.tls: self.logger.warn( - 'gRPC is running on insecure channel with TLS disabled.') + "gRPC is running on insecure channel with TLS disabled." + ) self.channel = self.create_insecure_channel(self.uri) else: self.channel = self.create_tls_channel( @@ -140,7 +149,7 @@ def __init__(self, self.root_certificate, self.disable_client_auth, self.certificate, - self.private_key + self.private_key, ) self.header = None @@ -153,8 +162,11 @@ def __init__(self, RetryOnRpcErrorClientInterceptor( sleeping_policy=ConstantBackoff( logger=self.logger, - reconnect_interval=int(kwargs.get('client_reconnect_interval', 1)), - uri=self.uri), + reconnect_interval=int( + kwargs.get("client_reconnect_interval", 1) + ), + uri=self.uri, + ), status_for_retry=(grpc.StatusCode.UNAVAILABLE,), ), ) @@ -177,8 +189,14 @@ def create_insecure_channel(self, uri): """ return grpc.insecure_channel(uri, options=channel_options) - def create_tls_channel(self, uri, root_certificate, disable_client_auth, - certificate, private_key): + def create_tls_channel( + self, + uri, + root_certificate, + disable_client_auth, + certificate, + private_key, + ): """ Set a secure gRPC channel (i.e. TLS).
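The constructor above registers `RetryOnRpcErrorClientInterceptor` with a `ConstantBackoff` policy on the channel. A minimal sketch of the same wiring in isolation, using only the public `grpc` API (the URI below is a placeholder, not a real aggregator address):

```python
import logging

import grpc

from openfl.experimental.transport.grpc.aggregator_client import (
    ConstantBackoff,
    RetryOnRpcErrorClientInterceptor,
)

logger = logging.getLogger(__name__)
uri = "localhost:50051"  # placeholder aggregator address

# Retry only UNAVAILABLE responses, sleeping 1 second between attempts,
# mirroring the defaults used in AggregatorGRPCClient.__init__ above.
retry = RetryOnRpcErrorClientInterceptor(
    sleeping_policy=ConstantBackoff(
        reconnect_interval=1, logger=logger, uri=uri
    ),
    status_for_retry=(grpc.StatusCode.UNAVAILABLE,),
)
channel = grpc.intercept_channel(grpc.insecure_channel(uri), retry)
```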
@@ -193,17 +211,17 @@ def create_tls_channel(self, uri, root_certificate, disable_client_auth, Returns: A secure gRPC channel object """ - with open(root_certificate, 'rb') as f: + with open(root_certificate, "rb") as f: root_certificate_b = f.read() if disable_client_auth: - self.logger.warn('Client-side authentication is disabled.') + self.logger.warn("Client-side authentication is disabled.") private_key_b = None certificate_b = None else: - with open(private_key, 'rb') as f: + with open(private_key, "rb") as f: private_key_b = f.read() - with open(certificate, 'rb') as f: + with open(certificate, "rb") as f: certificate_b = f.read() credentials = grpc.ssl_channel_credentials( @@ -212,15 +230,14 @@ def create_tls_channel(self, uri, root_certificate, disable_client_auth, certificate_chain=certificate_b, ) - return grpc.secure_channel( - uri, credentials, options=channel_options) + return grpc.secure_channel(uri, credentials, options=channel_options) def _set_header(self, collaborator_name): self.header = aggregator_pb2.MessageHeader( sender=collaborator_name, receiver=self.aggregator_uuid, federation_uuid=self.federation_uuid, - single_col_cert_common_name=self.single_col_cert_common_name or '' + single_col_cert_common_name=self.single_col_cert_common_name or "", ) def validate_response(self, reply, collaborator_name): @@ -231,21 +248,19 @@ def validate_response(self, reply, collaborator_name): # check that federation id matches check_equal( - reply.header.federation_uuid, - self.federation_uuid, - self.logger + reply.header.federation_uuid, self.federation_uuid, self.logger ) # check that there is agreement on the single_col_cert_common_name check_equal( reply.header.single_col_cert_common_name, - self.single_col_cert_common_name or '', - self.logger + self.single_col_cert_common_name or "", + self.logger, ) def disconnect(self): """Close the gRPC channel.""" - self.logger.debug(f'Disconnecting from gRPC server at {self.uri}') + self.logger.debug(f"Disconnecting from gRPC server at {self.uri}") self.channel.close() def reconnect(self): @@ -261,10 +276,10 @@ def reconnect(self): self.root_certificate, self.disable_client_auth, self.certificate, - self.private_key + self.private_key, ) - self.logger.debug(f'Connecting to gRPC at {self.uri}') + self.logger.debug(f"Connecting to gRPC at {self.uri}") self.stub = aggregator_pb2_grpc.AggregatorStub( grpc.intercept_channel(self.channel, *self.interceptors) @@ -272,8 +287,9 @@ def reconnect(self): @_atomic_connection @_resend_data_on_reconnection - def send_task_results(self, collaborator_name, round_number, next_step, - clone_bytes): + def send_task_results( + self, collaborator_name, round_number, next_step, clone_bytes + ): """Send next function name to aggregator.""" self._set_header(collaborator_name) request = aggregator_pb2.TaskResultsRequest( @@ -281,7 +297,7 @@ def send_task_results(self, collaborator_name, round_number, next_step, collab_name=collaborator_name, round_number=round_number, next_step=next_step, - execution_environment=clone_bytes + execution_environment=clone_bytes, ) response = self.stub.SendTaskResults(request) @@ -299,12 +315,19 @@ def get_tasks(self, collaborator_name): response = self.stub.GetTasks(request) self.validate_response(response, collaborator_name) - return (response.round_number, response.function_name, - response.execution_environment, response.sleep_time, response.quit) + return ( + response.round_number, + response.function_name, + response.execution_environment, + response.sleep_time,
response.quit, + ) @_atomic_connection @_resend_data_on_reconnection - def call_checkpoint(self, collaborator_name, clone_bytes, function, stream_buffer): + def call_checkpoint( + self, collaborator_name, clone_bytes, function, stream_buffer + ): """Perform checkpoint for collaborator task.""" self._set_header(collaborator_name) diff --git a/openfl/experimental/transport/grpc/aggregator_server.py b/openfl/experimental/transport/grpc/aggregator_server.py index 5675036e43..e85ed17e87 100644 --- a/openfl/experimental/transport/grpc/aggregator_server.py +++ b/openfl/experimental/transport/grpc/aggregator_server.py @@ -1,24 +1,20 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """AggregatorGRPCServer module.""" import logging from concurrent.futures import ThreadPoolExecutor -from random import random from multiprocessing import cpu_count +from random import random from time import sleep -from grpc import server -from grpc import ssl_server_credentials -from grpc import StatusCode +from grpc import StatusCode, server, ssl_server_credentials -from openfl.experimental.protocols import aggregator_pb2 -from openfl.experimental.protocols import aggregator_pb2_grpc -from openfl.utilities import check_equal -from openfl.utilities import check_is_in - -from .grpc_channel_options import channel_options +from openfl.experimental.protocols import aggregator_pb2, aggregator_pb2_grpc +from openfl.experimental.transport.grpc.grpc_channel_options import ( + channel_options, +) +from openfl.utilities import check_equal, check_is_in logger = logging.getLogger(__name__) @@ -26,15 +22,17 @@ class AggregatorGRPCServer(aggregator_pb2_grpc.AggregatorServicer): """gRPC server class for the Aggregator.""" - def __init__(self, - aggregator, - agg_port, - tls=True, - disable_client_auth=False, - root_certificate=None, - certificate=None, - private_key=None, - **kwargs): + def __init__( + self, + aggregator, + agg_port, + tls=True, + disable_client_auth=False, + root_certificate=None, + certificate=None, + private_key=None, + **kwargs, + ): """ Class initializer. @@ -51,7 +49,7 @@ def __init__(self, kwargs (dict): Additional arguments to pass into function """ self.aggregator = aggregator - self.uri = f'[::]:{agg_port}' + self.uri = f"[::]:{agg_port}" self.tls = tls self.disable_client_auth = disable_client_auth self.root_certificate = root_certificate @@ -76,17 +74,20 @@ def validate_collaborator(self, request, context): """ if self.tls: - common_name = context.auth_context()[ - 'x509_common_name'][0].decode('utf-8') + common_name = context.auth_context()["x509_common_name"][0].decode( + "utf-8" + ) collaborator_common_name = request.header.sender if not self.aggregator.valid_collaborator_cn_and_id( - common_name, collaborator_common_name): + common_name, collaborator_common_name + ): # Random delay in authentication failures sleep(5 * random()) context.abort( StatusCode.UNAUTHENTICATED, - f'Invalid collaborator. CN: |{common_name}| ' - f'collaborator_common_name: |{collaborator_common_name}|') + f"Invalid collaborator. 
CN: |{common_name}| " + f"collaborator_common_name: |{collaborator_common_name}|", + ) def get_header(self, collaborator_name): """ @@ -100,7 +101,7 @@ def get_header(self, collaborator_name): sender=self.aggregator.uuid, receiver=collaborator_name, federation_uuid=self.aggregator.federation_uuid, - single_col_cert_common_name=self.aggregator.single_col_cert_common_name + single_col_cert_common_name=self.aggregator.single_col_cert_common_name, ) def check_request(self, request): @@ -112,20 +113,25 @@ def check_request(self, request): Request sent from a collaborator that requires validation """ # TODO improve this check. the sender name could be spoofed - check_is_in(request.header.sender, self.aggregator.authorized_cols, self.logger) + check_is_in( + request.header.sender, self.aggregator.authorized_cols, self.logger + ) # check that the message is for me check_equal(request.header.receiver, self.aggregator.uuid, self.logger) # check that the message is for my federation check_equal( - request.header.federation_uuid, self.aggregator.federation_uuid, self.logger) + request.header.federation_uuid, + self.aggregator.federation_uuid, + self.logger, + ) # check that we agree on the single cert common name check_equal( request.header.single_col_cert_common_name, self.aggregator.single_col_cert_common_name, - self.logger + self.logger, ) def SendTaskResults(self, request, context): # NOQA:N802 @@ -140,8 +146,8 @@ def SendTaskResults(self, request, context): # NOQA:N802 self.validate_collaborator(request, context) self.check_request(request) collaborator_name = request.header.sender - round_number = request.round_number, - next_step = request.next_step, + round_number = (request.round_number,) + next_step = (request.next_step,) execution_environment = request.execution_environment _ = self.aggregator.send_task_results( @@ -164,8 +170,7 @@ def GetTasks(self, request, context): # NOQA:N802 self.check_request(request) collaborator_name = request.header.sender - rn, f, ee, st, q = self.aggregator.get_tasks( - request.header.sender) + rn, f, ee, st, q = self.aggregator.get_tasks(request.header.sender) return aggregator_pb2.GetTasksResponse( header=self.get_header(collaborator_name), @@ -173,7 +178,7 @@ def GetTasks(self, request, context): # NOQA:N802 function_name=f, execution_environment=ee, sleep_time=st, - quit=q + quit=q, ) def CallCheckpoint(self, request, context): # NOQA:N802 @@ -202,34 +207,36 @@ def CallCheckpoint(self, request, context): # NOQA:N802 def get_server(self): """Return gRPC server.""" - self.server = server(ThreadPoolExecutor(max_workers=cpu_count()), - options=channel_options) + self.server = server( + ThreadPoolExecutor(max_workers=cpu_count()), options=channel_options + ) aggregator_pb2_grpc.add_AggregatorServicer_to_server(self, self.server) if not self.tls: self.logger.warn( - 'gRPC is running on insecure channel with TLS disabled.') + "gRPC is running on insecure channel with TLS disabled." 
+ ) port = self.server.add_insecure_port(self.uri) - self.logger.info(f'Insecure port: {port}') + self.logger.info(f"Insecure port: {port}") else: - with open(self.private_key, 'rb') as f: + with open(self.private_key, "rb") as f: private_key_b = f.read() - with open(self.certificate, 'rb') as f: + with open(self.certificate, "rb") as f: certificate_b = f.read() - with open(self.root_certificate, 'rb') as f: + with open(self.root_certificate, "rb") as f: root_certificate_b = f.read() if self.disable_client_auth: - self.logger.warn('Client-side authentication is disabled.') + self.logger.warn("Client-side authentication is disabled.") self.server_credentials = ssl_server_credentials( ((private_key_b, certificate_b),), root_certificates=root_certificate_b, - require_client_auth=not self.disable_client_auth + require_client_auth=not self.disable_client_auth, ) self.server.add_secure_port(self.uri, self.server_credentials) @@ -240,7 +247,7 @@ def serve(self): """Start an aggregator gRPC service.""" self.get_server() - self.logger.info('Starting Aggregator gRPC Server') + self.logger.info("Starting Aggregator gRPC Server") self.server.start() self.is_server_started = True try: diff --git a/openfl/experimental/transport/grpc/exceptions.py b/openfl/experimental/transport/grpc/exceptions.py index 5bd19315c0..3af78b9a23 100644 --- a/openfl/experimental/transport/grpc/exceptions.py +++ b/openfl/experimental/transport/grpc/exceptions.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Exceptions that occur during service interaction.""" diff --git a/openfl/experimental/transport/grpc/grpc_channel_options.py b/openfl/experimental/transport/grpc/grpc_channel_options.py index 229dd45e51..6267f9ad41 100644 --- a/openfl/experimental/transport/grpc/grpc_channel_options.py +++ b/openfl/experimental/transport/grpc/grpc_channel_options.py @@ -1,11 +1,11 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -max_metadata_size = 32 * 2 ** 20 -max_message_length = 2 ** 30 +max_metadata_size = 32 * 2**20 +max_message_length = 2**30 channel_options = [ - ('grpc.max_metadata_size', max_metadata_size), - ('grpc.max_send_message_length', max_message_length), - ('grpc.max_receive_message_length', max_message_length) + ("grpc.max_metadata_size", max_metadata_size), + ("grpc.max_send_message_length", max_message_length), + ("grpc.max_receive_message_length", max_message_length), ] diff --git a/openfl/experimental/utilities/__init__.py b/openfl/experimental/utilities/__init__.py index 2272d1459a..1375a65f81 100644 --- a/openfl/experimental/utilities/__init__.py +++ b/openfl/experimental/utilities/__init__.py @@ -1,49 +1,28 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.experimental.utilities package.""" -from .metaflow_utils import MetaflowInterface -from .transitions import ( - should_transfer, - aggregator_to_collaborator, - collaborator_to_aggregator, -) -from .exceptions import ( - SerializationError, - ResourcesNotAvailableError, +from openfl.experimental.utilities.exceptions import ( ResourcesAllocationError, + ResourcesNotAvailableError, + SerializationError, ) -from .stream_redirect import ( - RedirectStdStreamBuffer, +from openfl.experimental.utilities.metaflow_utils import MetaflowInterface +from openfl.experimental.utilities.resources import get_number_of_gpus +from openfl.experimental.utilities.runtime_utils import ( + check_resource_allocation, + checkpoint, + 
filter_attributes, + generate_artifacts, + parse_attrs, +) +from openfl.experimental.utilities.stream_redirect import ( RedirectStdStream, + RedirectStdStreamBuffer, RedirectStdStreamContext, ) -from .resources import get_number_of_gpus -from .runtime_utils import ( - parse_attrs, - generate_artifacts, - filter_attributes, - checkpoint, - check_resource_allocation, +from openfl.experimental.utilities.transitions import ( + aggregator_to_collaborator, + collaborator_to_aggregator, + should_transfer, ) - - -__all__ = [ - "MetaflowInterface", - "should_transfer", - "aggregator_to_collaborator", - "collaborator_to_aggregator", - "SerializationError", - "ResourcesNotAvailableError", - "ResourcesAllocationError", - "RedirectStdStreamBuffer", - "RedirectStdStream", - "RedirectStdStreamContext", - "get_number_of_gpus", - "parse_attrs", - "generate_artifacts", - "filter_attributes", - "checkpoint", - "check_resource_allocation", -] diff --git a/openfl/experimental/utilities/exceptions.py b/openfl/experimental/utilities/exceptions.py index 12a307d271..caabaded18 100644 --- a/openfl/experimental/utilities/exceptions.py +++ b/openfl/experimental/utilities/exceptions.py @@ -1,19 +1,23 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 + class SerializationError(Exception): + def __init__(self, *args: object) -> None: super().__init__(*args) pass class ResourcesNotAvailableError(Exception): + def __init__(self, *args: object) -> None: super().__init__(*args) pass class ResourcesAllocationError(Exception): + def __init__(self, *args: object) -> None: super().__init__(*args) pass diff --git a/openfl/experimental/utilities/metaflow_utils.py b/openfl/experimental/utilities/metaflow_utils.py index 0d08f5265c..36dd72a2f6 100644 --- a/openfl/experimental/utilities/metaflow_utils.py +++ b/openfl/experimental/utilities/metaflow_utils.py @@ -1,70 +1,71 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.experimental.utilities.metaflow_utils module.""" from __future__ import annotations + +import ast +import fcntl +import hashlib from datetime import datetime -from metaflow.metaflow_environment import MetaflowEnvironment -from metaflow.plugins import LocalMetadataProvider -from metaflow.datastore import FlowDataStore, DATASTORES -from metaflow.graph import DAGNode, FlowGraph, StepVisitor -from metaflow.graph import deindent_docstring -from metaflow.datastore.task_datastore import TaskDataStore +from pathlib import Path + +# getsource only used to determine structure of FlowGraph +from typing import TYPE_CHECKING + +import cloudpickle as pickle +import ray +from dill.source import getsource # nosec +from metaflow.datastore import DATASTORES, FlowDataStore from metaflow.datastore.exceptions import ( DataException, UnpicklableArtifactException, ) -from metaflow.datastore.task_datastore import only_if_not_done, require_mode -import cloudpickle as pickle -import ray -import ast -from pathlib import Path -from metaflow.runtime import TruncatedBuffer, mflog_msg, MAX_LOG_SIZE +from metaflow.datastore.task_datastore import ( + TaskDataStore, + only_if_not_done, + require_mode, +) +from metaflow.graph import DAGNode, FlowGraph, StepVisitor, deindent_docstring +from metaflow.metaflow_environment import MetaflowEnvironment from metaflow.mflog import RUNTIME_LOG_SOURCE +from metaflow.plugins import LocalMetadataProvider +from metaflow.runtime import MAX_LOG_SIZE, TruncatedBuffer, mflog_msg from metaflow.task import MetaDatum -import fcntl -import 
hashlib -from dill.source import getsource # nosec -# getsource only used to determine structure of FlowGraph -from typing import TYPE_CHECKING + if TYPE_CHECKING: from openfl.experimental.interface import FLSpec + +import base64 +import json +import uuid from io import StringIO -from typing import Generator, Any, Type +from typing import Any, Generator, Type +from metaflow import __version__ as mf_version from metaflow.plugins.cards.card_modules.basic import ( - DefaultCard, - TaskInfoComponent, -) -from metaflow.plugins.cards.card_modules.basic import ( + CSS_PATH, + JS_PATH, + RENDER_TEMPLATE_PATH, DagComponent, - SectionComponent, + DefaultCard, PageComponent, -) -from metaflow.plugins.cards.card_modules.basic import ( - RENDER_TEMPLATE_PATH, - JS_PATH, - CSS_PATH, -) -from metaflow.plugins.cards.card_modules.basic import ( + SectionComponent, + TaskInfoComponent, read_file, transform_flow_graph, ) -from metaflow import __version__ as mf_version - -import json -import base64 -import uuid class SystemMutex: + def __init__(self, name): self.name = name def __enter__(self): - lock_id = hashlib.new('md5', self.name.encode("utf8"), - usedforsecurity=False).hexdigest() # nosec + lock_id = hashlib.new( + "md5", self.name.encode("utf8"), usedforsecurity=False + ).hexdigest() # nosec # MD5sum used for concurrency purposes, not security self.fp = open(f"/tmp/.lock-{lock_id}.lck", "wb") fcntl.flock(self.fp.fileno(), fcntl.LOCK_EX) @@ -75,6 +76,7 @@ def __exit__(self, _type, value, tb): class Flow: + def __init__(self, name): """Mock flow for metaflow internals""" self.name = name @@ -82,6 +84,7 @@ def __init__(self, name): @ray.remote class Counter(object): + def __init__(self): self.value = 0 @@ -94,6 +97,7 @@ def get_counter(self): class DAGnode(DAGNode): + def __init__(self, func_ast, decos, doc): self.name = func_ast.name self.func_lineno = func_ast.lineno @@ -186,6 +190,7 @@ def _parse(self, func_ast): class StepVisitor(StepVisitor): + def __init__(self, nodes, flow): super().__init__(nodes, flow) @@ -196,6 +201,7 @@ def visit_FunctionDef(self, node): # NOQA: N802 class FlowGraph(FlowGraph): + def __init__(self, flow): self.name = flow.__name__ self.nodes = self._create_nodes(flow) @@ -217,6 +223,7 @@ def _create_nodes(self, flow): class TaskDataStore(TaskDataStore): + def __init__( self, flow_datastore, @@ -321,6 +328,7 @@ def pickle_iter(): class FlowDataStore(FlowDataStore): + def __init__( self, flow_name, @@ -365,6 +373,7 @@ def get_task_datastore( class MetaflowInterface: + def __init__(self, flow: Type[FLSpec], backend: str = "ray"): """ Wrapper class for the metaflow tooling modified to work with the @@ -442,7 +451,7 @@ def save_artifacts( task_name: str, task_id: int, buffer_out: Type[StringIO], - buffer_err: Type[StringIO] + buffer_err: Type[StringIO], ) -> None: """ Use metaflow task datastore to save flow attributes, stdout, and stderr @@ -511,11 +520,11 @@ def load_artifacts(self, artifact_names, task_name, task_id): return task_datastore.load_artifacts(artifact_names) def emit_log( - self, - msgbuffer_out: Type[StringIO], - msgbuffer_err: Type[StringIO], - task_datastore: Type[TaskDataStore], - system_msg: bool = False + self, + msgbuffer_out: Type[StringIO], + msgbuffer_err: Type[StringIO], + task_datastore: Type[TaskDataStore], + system_msg: bool = False, ) -> None: """ This function writes the stdout and stderr to Metaflow TaskDatastore diff --git a/openfl/experimental/utilities/resources.py b/openfl/experimental/utilities/resources.py index 24689bb82e..6c0ed54c3c 100644 --- 
a/openfl/experimental/utilities/resources.py +++ b/openfl/experimental/utilities/resources.py @@ -3,7 +3,7 @@ """openfl.experimental.utilities.resources module.""" from logging import getLogger -from subprocess import run, PIPE +from subprocess import PIPE, run logger = getLogger(__name__) @@ -24,6 +24,8 @@ def get_number_of_gpus() -> int: stdout = op.stdout.decode().strip() return len(stdout.split("\n")) except FileNotFoundError: - logger.warning(f'No GPUs found! If this is a mistake please try running "{command}" ' - + 'manually.') + logger.warning( + f'No GPUs found! If this is a mistake please try running "{command}" ' + + "manually." + ) return 0 diff --git a/openfl/experimental/utilities/runtime_utils.py b/openfl/experimental/utilities/runtime_utils.py index e39ed9f36d..3421c5d211 100644 --- a/openfl/experimental/utilities/runtime_utils.py +++ b/openfl/experimental/utilities/runtime_utils.py @@ -1,12 +1,13 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.experimental.utilities package.""" -import itertools import inspect -import numpy as np +import itertools from types import MethodType + +import numpy as np + from openfl.experimental.utilities import ResourcesAllocationError diff --git a/openfl/experimental/utilities/stream_redirect.py b/openfl/experimental/utilities/stream_redirect.py index daa2b7bc78..5f7a25fd3d 100644 --- a/openfl/experimental/utilities/stream_redirect.py +++ b/openfl/experimental/utilities/stream_redirect.py @@ -1,10 +1,9 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """openfl.experimental.utilities.stream_redirect module.""" -import sys import io +import sys from copy import deepcopy @@ -56,6 +55,7 @@ class RedirectStdStreamContext: """ Context Manager that enables redirection of stdout & stderr """ + def __init__(self): self.stdstreambuffer = RedirectStdStreamBuffer() @@ -65,8 +65,12 @@ def __enter__(self): """ self.__old_stdout = sys.stdout self.__old_stderr = sys.stderr - sys.stdout = RedirectStdStream(self.stdstreambuffer._stdoutbuff, sys.stdout) - sys.stderr = RedirectStdStream(self.stdstreambuffer._stderrbuff, sys.stderr) + sys.stdout = RedirectStdStream( + self.stdstreambuffer._stdoutbuff, sys.stdout + ) + sys.stderr = RedirectStdStream( + self.stdstreambuffer._stderrbuff, sys.stderr + ) return self.stdstreambuffer diff --git a/openfl/experimental/utilities/transitions.py b/openfl/experimental/utilities/transitions.py index a64a8c38dc..b134a73690 100644 --- a/openfl/experimental/utilities/transitions.py +++ b/openfl/experimental/utilities/transitions.py @@ -1,6 +1,5 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Detect criteria for transitions in placement.""" diff --git a/openfl/experimental/utilities/ui.py b/openfl/experimental/utilities/ui.py index 983060be04..ae10910ffd 100644 --- a/openfl/experimental/utilities/ui.py +++ b/openfl/experimental/utilities/ui.py @@ -1,13 +1,15 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from openfl.experimental.utilities.metaflow_utils import DefaultCard, FlowGraph -from pathlib import Path import os import webbrowser +from pathlib import Path + +from openfl.experimental.utilities.metaflow_utils import DefaultCard, FlowGraph class InspectFlow: + def __init__( self, flow_obj, diff --git a/openfl/experimental/workspace_export/__init__.py b/openfl/experimental/workspace_export/__init__.py index ba88041c78..11ac57f2b6 100644 --- 
a/openfl/experimental/workspace_export/__init__.py +++ b/openfl/experimental/workspace_export/__init__.py @@ -1,6 +1,4 @@ # Copyright (C) 2020-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from .export import WorkspaceExport - -__all__ = ["WorkspaceExport"] +from openfl.experimental.workspace_export.export import WorkspaceExport diff --git a/openfl/experimental/workspace_export/export.py b/openfl/experimental/workspace_export/export.py index ad338a5906..16490324d4 100644 --- a/openfl/experimental/workspace_export/export.py +++ b/openfl/experimental/workspace_export/export.py @@ -2,19 +2,19 @@ # SPDX-License-Identifier: Apache-2.0 """Workspace Builder module.""" -import re -import yaml import ast -import astor -import inspect import importlib -import nbformat - -from shutil import copytree +import inspect +import re from logging import getLogger from pathlib import Path +from shutil import copytree +import astor +import nbformat +import yaml from nbdev.export import nb_export + from openfl.experimental.interface.cli.cli_helper import print_tree @@ -30,29 +30,39 @@ class WorkspaceExport: Returns: None """ - def __init__(self, - notebook_path: str, - output_workspace: str) -> None: + + def __init__(self, notebook_path: str, output_workspace: str) -> None: self.logger = getLogger(__name__) self.notebook_path = Path(notebook_path).resolve() self.output_workspace_path = Path(output_workspace).resolve() self.output_workspace_path.parent.mkdir(parents=True, exist_ok=True) - self.template_workspace_path = Path(f"{__file__}").parent.parent.parent.parent.joinpath( - "openfl-workspace", "experimental", "template_workspace" - ).resolve(strict=True) + self.template_workspace_path = ( + Path(f"{__file__}") + .parent.parent.parent.parent.joinpath( + "openfl-workspace", "experimental", "template_workspace" + ) + .resolve(strict=True) + ) # Copy template workspace to output directory - self.created_workspace_path = Path(copytree( - self.template_workspace_path, self.output_workspace_path)) - self.logger.info(f"Copied template workspace to {self.created_workspace_path}") + self.created_workspace_path = Path( + copytree(self.template_workspace_path, self.output_workspace_path) + ) + self.logger.info( + f"Copied template workspace to {self.created_workspace_path}" + ) self.logger.info("Converting Jupyter notebook to python script...") export_filename = self.__get_exp_name() - self.script_path = Path(self.__convert_to_python( - self.notebook_path, self.created_workspace_path.joinpath("src"), - f"{export_filename}.py")).resolve() + self.script_path = Path( + self.__convert_to_python( + self.notebook_path, + self.created_workspace_path.joinpath("src"), + f"{export_filename}.py", + ) + ).resolve() print_tree(self.created_workspace_path, level=2) # Generated python script name without .py extension @@ -72,11 +82,15 @@ def __get_exp_name(self): code = cell.source match = re.search(r"#\s*\|\s*default_exp\s+(\w+)", code) if match: - self.logger.info(f"Retrieved {match.group(1)} from default_exp") + self.logger.info( + f"Retrieved {match.group(1)} from default_exp" + ) return match.group(1) return None - def __convert_to_python(self, notebook_path: Path, output_path: Path, export_filename): + def __convert_to_python( + self, notebook_path: Path, output_path: Path, export_filename + ): nb_export(notebook_path, output_path) return Path(output_path).joinpath(export_filename).resolve()
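For reviewers unfamiliar with nbdev: `__get_exp_name` scans the notebook's code cells for the `# | default_exp <name>` directive, which names the module that `nb_export` writes into `src/`. A tiny self-contained check of the regex used above (the cell source is a made-up example):

```python
import re

# Hypothetical notebook cell; WorkspaceExport uses this directive to pick
# the exported script's filename (here it would become src/my_flow.py).
cell_source = "# | default_exp my_flow\nfrom openfl.experimental.interface import FLSpec"

match = re.search(r"#\s*\|\s*default_exp\s+(\w+)", cell_source)
assert match is not None
assert match.group(1) == "my_flow"
```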
@@ -119,8 +133,10 @@ def __get_class_arguments(self, class_name): # Find class from imported python script module for idx, attr in enumerate(self.available_modules_in_exported_script): if attr == class_name: - cls = getattr(self.exported_script_module, - self.available_modules_in_exported_script[idx]) + cls = getattr( + self.exported_script_module, + self.available_modules_in_exported_script[idx], + ) # If class not found if "cls" not in locals(): @@ -131,8 +147,11 @@ def __get_class_arguments(self, class_name): if "__init__" in cls.__dict__: init_signature = inspect.signature(cls.__init__) # Extract the parameter names (excluding 'self', 'args', and 'kwargs') - arg_names = [param for param in init_signature.parameters if param not in ( - "self", "args", "kwargs")] + arg_names = [ + param + for param in init_signature.parameters + if param not in ("self", "args", "kwargs") + ] return arg_names return [] self.logger.error(f"{cls} is not a class") @@ -148,7 +167,11 @@ def __get_class_name_and_sourcecode_from_parent_class(self, parent_class): # Going through all attributes in imported python script for attr in self.available_modules_in_exported_script: t = getattr(self.exported_script_module, attr) - if inspect.isclass(t) and t != parent_class and issubclass(t, parent_class): + if ( + inspect.isclass(t) + and t != parent_class + and issubclass(t, parent_class) + ): return inspect.getsource(t), attr return None, None @@ -157,15 +180,15 @@ def __extract_class_initializing_args(self, class_name): """ Provided the name of a class, returns its expected arguments and their values in the form of a dictionary """ - instantiation_args = { - "args": {}, "kwargs": {} - } + instantiation_args = {"args": {}, "kwargs": {}} with open(self.script_path, "r") as s: tree = ast.parse(s.read()) for node in ast.walk(tree): - if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): + if isinstance(node, ast.Call) and isinstance( + node.func, ast.Name + ): if node.func.id == class_name: # We found an instantiation of the class for arg in node.args: @@ -174,9 +197,13 @@ def __extract_class_initializing_args(self, class_name): # Use the variable name as the argument value instantiation_args["args"][arg.id] = arg.id elif isinstance(arg, ast.Constant): - instantiation_args["args"][arg.s] = astor.to_source(arg) + instantiation_args["args"][arg.s] = ( + astor.to_source(arg) + ) else: - instantiation_args["args"][arg.arg] = astor.to_source(arg).strip() + instantiation_args["args"][arg.arg] = ( + astor.to_source(arg).strip() + ) for kwarg in node.keywords: # Iterate through keyword arguments @@ -200,12 +227,14 @@ def __import_exported_script(self): """ Imports the generated python script with the help of importlib """ - import sys import importlib + import sys sys.path.append(str(self.script_path.parent)) self.exported_script_module = importlib.import_module(self.script_name) - self.available_modules_in_exported_script = dir(self.exported_script_module) + self.available_modules_in_exported_script = dir( + self.exported_script_module + ) def __read_yaml(self, path): with open(path, "r") as y: @@ -252,11 +281,16 @@ def generate_requirements(self): line_nos.append(i) # Avoid commented lines, libraries from *.txt file, or openfl.git # installation - if not line.startswith("#") and "-r" not in line and "openfl.git" not in line: + if ( + not line.startswith("#") + and "-r" not in line + and "openfl.git" not in line + ): requirements.append(f"{line.split(' ')[-1].strip()}\n") requirements_filepath = str( - self.created_workspace_path.joinpath("requirements.txt").resolve()) + self.created_workspace_path.joinpath("requirements.txt").resolve() + ) # Write libraries
found in requirements.txt with open(requirements_filepath, "a") as f: @@ -277,22 +311,28 @@ def generate_plan_yaml(self): importlib.import_module("openfl.experimental.interface"), "FLSpec" ) # Get flow classname - _, self.flow_class_name = self.__get_class_name_and_sourcecode_from_parent_class(flspec) + _, self.flow_class_name = ( + self.__get_class_name_and_sourcecode_from_parent_class(flspec) + ) # Get expected arguments of flow class - self.flow_class_expected_arguments = self.__get_class_arguments(self.flow_class_name) + self.flow_class_expected_arguments = self.__get_class_arguments( + self.flow_class_name + ) # Get provided arguments to flow class - self.arguments_passed_to_initialize = self.__extract_class_initializing_args( - self.flow_class_name) + self.arguments_passed_to_initialize = ( + self.__extract_class_initializing_args(self.flow_class_name) + ) - plan = self.created_workspace_path.joinpath("plan", "plan.yaml").resolve() + plan = self.created_workspace_path.joinpath( + "plan", "plan.yaml" + ).resolve() data = self.__read_yaml(plan) if data is None: - data["federated_flow"] = { - "settings": {}, - "template": "" - } + data["federated_flow"] = {"settings": {}, "template": ""} - data["federated_flow"]["template"] = f"src.{self.script_name}.{self.flow_class_name}" + data["federated_flow"][ + "template" + ] = f"src.{self.script_name}.{self.flow_class_name}" def update_dictionary(args: dict, data: dict, dtype: str = "args"): for idx, (k, v) in enumerate(args.items()): @@ -304,9 +344,7 @@ def update_dictionary(args: dict, data: dict, dtype: str = "args"): elif dtype == "kwargs": if v is not None and type(v) not in (int, str, bool): v = f"src.{self.script_name}.{k}" - data["federated_flow"]["settings"].update({ - k: v - }) + data["federated_flow"]["settings"].update({k: v}) # Find positional arguments of flow class and it's values pos_args = self.arguments_passed_to_initialize["args"] @@ -328,26 +366,36 @@ def generate_data_yaml(self): # If flow classname is not yet found if not hasattr(self, "flow_class_name"): flspec = getattr( - importlib.import_module("openfl.experimental.interface"), "FLSpec" + importlib.import_module("openfl.experimental.interface"), + "FLSpec", + ) + _, self.flow_class_name = ( + self.__get_class_name_and_sourcecode_from_parent_class(flspec) ) - _, self.flow_class_name = self.__get_class_name_and_sourcecode_from_parent_class( - flspec) # Import flow class - federated_flow_class = getattr(self.exported_script_module, self.flow_class_name) + federated_flow_class = getattr( + self.exported_script_module, self.flow_class_name + ) # Find federated_flow._runtime and federated_flow._runtime.collaborators for t in self.available_modules_in_exported_script: t = getattr(self.exported_script_module, t) if isinstance(t, federated_flow_class): if not hasattr(t, "_runtime"): - raise Exception("Unable to locate LocalRuntime instantiation") + raise Exception( + "Unable to locate LocalRuntime instantiation" + ) runtime = t._runtime if not hasattr(runtime, "collaborators"): - raise Exception("LocalRuntime instance does not have collaborators") + raise Exception( + "LocalRuntime instance does not have collaborators" + ) collaborators_names = runtime.collaborators break - data_yaml = self.created_workspace_path.joinpath("plan", "data.yaml").resolve() + data_yaml = self.created_workspace_path.joinpath( + "plan", "data.yaml" + ).resolve() data = self.__read_yaml(data_yaml) if data is None: data = {} @@ -359,12 +407,13 @@ def generate_data_yaml(self): data["aggregator"] = { 
"callable_func": { "settings": {}, - "template": f"src.{self.script_name}.{private_attrs_callable.__name__}" + "template": f"src.{self.script_name}.{private_attrs_callable.__name__}", } } # Find arguments expected by Aggregator - arguments_passed_to_initialize = self.__extract_class_initializing_args("Aggregator")[ - "kwargs"] + arguments_passed_to_initialize = ( + self.__extract_class_initializing_args("Aggregator")["kwargs"] + ) agg_kwargs = aggregator.kwargs for key, value in agg_kwargs.items(): if isinstance(value, (int, str, bool)): @@ -375,15 +424,13 @@ def generate_data_yaml(self): data["aggregator"]["callable_func"]["settings"][key] = value # Find arguments expected by Collaborator - arguments_passed_to_initialize = self.__extract_class_initializing_args("Collaborator")[ - "kwargs"] + arguments_passed_to_initialize = self.__extract_class_initializing_args( + "Collaborator" + )["kwargs"] for collab_name in collaborators_names: if collab_name not in data: data[collab_name] = { - "callable_func": { - "settings": {}, - "template": None - } + "callable_func": {"settings": {}, "template": None} } # Find collaborator details kw_args = runtime.get_collaborator_kwargs(collab_name) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..614c3b6243 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,7 @@ +[tool.black] +line-length = 80 + +[tool.isort] +profile = "black" +force_single_line = "False" +line_length = 80 \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 7713a3e629..466a7d1c1c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,9 +1,17 @@ [flake8] -# W503 Line break occurred before a binary operator. Update by W504 Line -# break occurred after a binary operator -# N812: lowercase imported as non lowercase. Allow "import torch.nn.functional as F" -ignore = W503, N812 +ignore = + # Conflicts with black + E203 + # Line break occurred before a binary operator. Update by W504 Line + W503 + # Allow "import torch.nn.functional as F" + N812 + +per-file-ignores = + # Unused imports in __init__.py are OK + **/__init__.py:F401 + select = E,F,W,N,C4,C90,C801 inline-quotes = ' multiline-quotes = ' diff --git a/shell/format.sh b/shell/format.sh new file mode 100755 index 0000000000..36f863dbc2 --- /dev/null +++ b/shell/format.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -Eeuo pipefail + +base_dir=$(dirname $(dirname $0)) + +# TODO: @karansh1 Apply across all modules +isort --sp "${base_dir}/pyproject.toml" openfl/experimental + +black --config "${base_dir}/pyproject.toml" openfl/experimental + +flake8 --config "${base_dir}/setup.cfg" openfl/experimental \ No newline at end of file diff --git a/shell/lint.sh b/shell/lint.sh new file mode 100755 index 0000000000..16d5da0ef4 --- /dev/null +++ b/shell/lint.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -Eeuo pipefail + +base_dir=$(dirname $(dirname $0)) + +# TODO: @karansh1 Apply across all modules +isort --sp "${base_dir}/pyproject.toml" --check openfl/experimental + +black --config "${base_dir}/pyproject.toml" --check openfl/experimental + +flake8 --config "${base_dir}/setup.cfg" openfl/experimental \ No newline at end of file