-
Notifications
You must be signed in to change notification settings - Fork 906
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for middleware layers (#2580)
Co-authored-by: Daniel J. Beutel <[email protected]> Co-authored-by: Charles Beauville <[email protected]>
- Loading branch information
1 parent
9507159
commit 251fdf1
Showing
7 changed files
with
261 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
Use Built-in Middleware Layers | ||
============================== | ||
|
||
**Note: This tutorial covers experimental features. The functionality and interfaces may change in future versions.** | ||
|
||
In this tutorial, we will learn how to utilize built-in middleware layers to augment the behavior of a ``FlowerCallable``. Middleware allows us to perform operations before and after a task is processed in the ``FlowerCallable``. | ||
|
||
What is middleware? | ||
------------------- | ||
|
||
Middleware is a callable that wraps around a ``FlowerCallable``. It can manipulate or inspect incoming tasks (``TaskIns``) in the ``Fwd`` and the resulting tasks (``TaskRes``) in the ``Bwd``. The signature for a middleware layer (``Layer``) is as follows: | ||
|
||
.. code-block:: python | ||
FlowerCallable = Callable[[Fwd], Bwd] | ||
Layer = Callable[[Fwd, FlowerCallable], Bwd] | ||
A typical middleware function might look something like this: | ||
|
||
.. code-block:: python | ||
def example_middleware(fwd: Fwd, ffn: FlowerCallable) -> Bwd: | ||
# Do something with Fwd before passing to the inner ``FlowerCallable``. | ||
bwd = ffn(fwd) | ||
# Do something with Bwd before returning. | ||
return bwd | ||
Using middleware layers | ||
----------------------- | ||
|
||
To use middleware layers in your ``FlowerCallable``, you can follow these steps: | ||
|
||
1. Import the required middleware | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
First, import the built-in middleware layers you intend to use: | ||
|
||
.. code-block:: python | ||
import flwr as fl | ||
from flwr.client.middleware import example_middleware1, example_middleware2 | ||
2. Define your client function | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
Define your client function (``client_fn``) that will be wrapped by the middleware: | ||
|
||
.. code-block:: python | ||
def client_fn(cid): | ||
# Your client code goes here. | ||
return # your client | ||
3. Create the ``FlowerCallable`` with middleware | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
Create your ``FlowerCallable`` and pass the middleware layers as a list to the ``layers`` argument. The order in which you provide the middleware layers matters: | ||
|
||
.. code-block:: python | ||
flower = fl.app.Flower( | ||
client_fn=client_fn, | ||
layers=[ | ||
example_middleware1, # Middleware layer 1 | ||
example_middleware2, # Middleware layer 2 | ||
] | ||
) | ||
Order of execution | ||
------------------ | ||
|
||
When the ``FlowerCallable`` runs, the middleware layers are executed in the order they are provided in the list: | ||
|
||
1. ``example_middleware1`` (outermost layer) | ||
2. ``example_middleware2`` (next layer) | ||
3. Message handler (core function that handles ``TaskIns`` and returns ``TaskRes``) | ||
4. ``example_middleware2`` (on the way back) | ||
5. ``example_middleware1`` (outermost layer on the way back) | ||
|
||
Each middleware has a chance to inspect and modify the ``TaskIns`` in the ``Fwd`` before passing it to the next layer, and likewise with the ``TaskRes`` in the ``Bwd`` before returning it up the stack. | ||
|
||
Conclusion | ||
---------- | ||
|
||
By following this guide, you have learned how to effectively use middleware layers to enhance your ``FlowerCallable``'s functionality. Remember that the order of middleware is crucial and affects how the input and output are processed. | ||
|
||
Enjoy building more robust and flexible ``FlowerCallable``s with middleware layers! |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright 2023 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Middleware layers.""" | ||
|
||
|
||
from .utils import make_ffn | ||
|
||
__all__ = [ | ||
"make_ffn", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Copyright 2023 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Utility functions for middleware layers.""" | ||
|
||
|
||
from typing import List | ||
|
||
from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer | ||
|
||
|
||
def make_ffn(ffn: FlowerCallable, layers: List[Layer]) -> FlowerCallable: | ||
""".""" | ||
|
||
def wrap_ffn(_ffn: FlowerCallable, _layer: Layer) -> FlowerCallable: | ||
def new_ffn(fwd: Fwd) -> Bwd: | ||
return _layer(fwd, _ffn) | ||
|
||
return new_ffn | ||
|
||
for layer in reversed(layers): | ||
ffn = wrap_ffn(ffn, layer) | ||
|
||
return ffn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright 2023 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for the utility functions.""" | ||
|
||
|
||
import unittest | ||
from typing import List | ||
|
||
from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer | ||
from flwr.client.workload_state import WorkloadState | ||
from flwr.proto.task_pb2 import TaskIns, TaskRes | ||
|
||
from .utils import make_ffn | ||
|
||
|
||
def make_mock_middleware(name: str, footprint: List[str]) -> Layer: | ||
"""Make a mock middleware layer.""" | ||
|
||
def middleware(fwd: Fwd, app: FlowerCallable) -> Bwd: | ||
footprint.append(name) | ||
fwd.task_ins.task_id += f"{name}" | ||
bwd = app(fwd) | ||
footprint.append(name) | ||
bwd.task_res.task_id += f"{name}" | ||
return bwd | ||
|
||
return middleware | ||
|
||
|
||
def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable: | ||
"""Make a mock app.""" | ||
|
||
def app(fwd: Fwd) -> Bwd: | ||
footprint.append(name) | ||
fwd.task_ins.task_id += f"{name}" | ||
return Bwd(task_res=TaskRes(task_id=name), state=WorkloadState({})) | ||
|
||
return app | ||
|
||
|
||
class TestMakeApp(unittest.TestCase): | ||
"""Tests for the `make_app` function.""" | ||
|
||
def test_multiple_middlewares(self) -> None: | ||
"""Test if multiple middlewares are called in the correct order.""" | ||
# Prepare | ||
footprint: List[str] = [] | ||
mock_app = make_mock_app("app", footprint) | ||
mock_middleware_names = [f"middleware{i}" for i in range(1, 15)] | ||
mock_middleware_layers = [ | ||
make_mock_middleware(name, footprint) for name in mock_middleware_names | ||
] | ||
task_ins = TaskIns() | ||
|
||
# Execute | ||
wrapped_app = make_ffn(mock_app, mock_middleware_layers) | ||
task_res = wrapped_app(Fwd(task_ins=task_ins, state=WorkloadState({}))).task_res | ||
|
||
# Assert | ||
trace = mock_middleware_names + ["app"] | ||
self.assertEqual(footprint, trace + list(reversed(mock_middleware_names))) | ||
# pylint: disable-next=no-member | ||
self.assertEqual(task_ins.task_id, "".join(trace)) | ||
self.assertEqual(task_res.task_id, "".join(reversed(trace))) | ||
|
||
def test_filter(self) -> None: | ||
"""Test if a middleware can filter incoming TaskIns.""" | ||
# Prepare | ||
footprint: List[str] = [] | ||
mock_app = make_mock_app("app", footprint) | ||
task_ins = TaskIns() | ||
|
||
def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd: | ||
footprint.append("filter") | ||
fwd.task_ins.task_id += "filter" | ||
# Skip calling app | ||
return Bwd(task_res=TaskRes(task_id="filter"), state=WorkloadState({})) | ||
|
||
# Execute | ||
wrapped_app = make_ffn(mock_app, [filter_layer]) | ||
task_res = wrapped_app(Fwd(task_ins=task_ins, state=WorkloadState({}))).task_res | ||
|
||
# Assert | ||
self.assertEqual(footprint, ["filter"]) | ||
# pylint: disable-next=no-member | ||
self.assertEqual(task_ins.task_id, "filter") | ||
self.assertEqual(task_res.task_id, "filter") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters