Skip to content

Commit

Permalink
Add support for middleware layers (#2580)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
Co-authored-by: Charles Beauville <[email protected]>
  • Loading branch information
3 people authored Dec 19, 2023
1 parent 9507159 commit 251fdf1
Show file tree
Hide file tree
Showing 7 changed files with 261 additions and 13 deletions.
87 changes: 87 additions & 0 deletions doc/source/how-to-use-built-in-middleware-layers.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
Use Built-in Middleware Layers
==============================

**Note: This tutorial covers experimental features. The functionality and interfaces may change in future versions.**

In this tutorial, we will learn how to utilize built-in middleware layers to augment the behavior of a ``FlowerCallable``. Middleware allows us to perform operations before and after a task is processed in the ``FlowerCallable``.

What is middleware?
-------------------

Middleware is a callable that wraps around a ``FlowerCallable``. It can manipulate or inspect incoming tasks (``TaskIns``) in the ``Fwd`` and the resulting tasks (``TaskRes``) in the ``Bwd``. The signature for a middleware layer (``Layer``) is as follows:

.. code-block:: python
FlowerCallable = Callable[[Fwd], Bwd]
Layer = Callable[[Fwd, FlowerCallable], Bwd]
A typical middleware function might look something like this:

.. code-block:: python
def example_middleware(fwd: Fwd, ffn: FlowerCallable) -> Bwd:
# Do something with Fwd before passing to the inner ``FlowerCallable``.
bwd = ffn(fwd)
# Do something with Bwd before returning.
return bwd
Using middleware layers
-----------------------

To use middleware layers in your ``FlowerCallable``, you can follow these steps:

1. Import the required middleware
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

First, import the built-in middleware layers you intend to use:

.. code-block:: python
import flwr as fl
from flwr.client.middleware import example_middleware1, example_middleware2
2. Define your client function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Define your client function (``client_fn``) that will be wrapped by the middleware:

.. code-block:: python
def client_fn(cid):
# Your client code goes here.
return # your client
3. Create the ``FlowerCallable`` with middleware
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Create your ``FlowerCallable`` and pass the middleware layers as a list to the ``layers`` argument. The order in which you provide the middleware layers matters:

.. code-block:: python
flower = fl.app.Flower(
client_fn=client_fn,
layers=[
example_middleware1, # Middleware layer 1
example_middleware2, # Middleware layer 2
]
)
Order of execution
------------------

When the ``FlowerCallable`` runs, the middleware layers are executed in the order they are provided in the list:

1. ``example_middleware1`` (outermost layer)
2. ``example_middleware2`` (next layer)
3. Message handler (core function that handles ``TaskIns`` and returns ``TaskRes``)
4. ``example_middleware2`` (on the way back)
5. ``example_middleware1`` (outermost layer on the way back)

Each middleware has a chance to inspect and modify the ``TaskIns`` in the ``Fwd`` before passing it to the next layer, and likewise with the ``TaskRes`` in the ``Bwd`` before returning it up the stack.

Conclusion
----------

By following this guide, you have learned how to effectively use middleware layers to enhance your ``FlowerCallable``'s functionality. Remember that the order of middleware is crucial and affects how the input and output are processed.

Enjoy building more robust and flexible ``FlowerCallable``s with middleware layers!
1 change: 1 addition & 0 deletions doc/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ Problem-oriented how-to guides show step-by-step how to achieve a specific goal.
how-to-configure-logging
how-to-enable-ssl-connections
how-to-upgrade-to-flower-1.0
how-to-use-built-in-middleware-layers

.. toctree::
:maxdepth: 1
Expand Down
29 changes: 16 additions & 13 deletions src/py/flwr/client/flower.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@


import importlib
from typing import cast
from typing import List, Optional, cast

from flwr.client.message_handler.message_handler import handle
from flwr.client.typing import Bwd, ClientFn, Fwd
from flwr.client.middleware.utils import make_ffn
from flwr.client.typing import Bwd, ClientFn, Fwd, Layer


class Flower:
Expand Down Expand Up @@ -51,21 +52,23 @@ class Flower:
def __init__(
self,
client_fn: ClientFn, # Only for backward compatibility
layers: Optional[List[Layer]] = None,
) -> None:
self.client_fn = client_fn
# Create wrapper function for `handle`
def ffn(fwd: Fwd) -> Bwd: # pylint: disable=invalid-name
task_res, state_updated = handle(
client_fn=client_fn,
state=fwd.state,
task_ins=fwd.task_ins,
)
return Bwd(task_res=task_res, state=state_updated)

# Wrap middleware layers around the wrapped handle function
self._call = make_ffn(ffn, layers if layers is not None else [])

def __call__(self, fwd: Fwd) -> Bwd:
"""."""
# Execute the task
task_res, state_updated = handle(
client_fn=self.client_fn,
state=fwd.state,
task_ins=fwd.task_ins,
)
return Bwd(
task_res=task_res,
state=state_updated,
)
return self._call(fwd)


class LoadCallableError(Exception):
Expand Down
22 changes: 22 additions & 0 deletions src/py/flwr/client/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Middleware layers."""


from .utils import make_ffn

__all__ = [
"make_ffn",
]
35 changes: 35 additions & 0 deletions src/py/flwr/client/middleware/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for middleware layers."""


from typing import List

from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer


def make_ffn(ffn: FlowerCallable, layers: List[Layer]) -> FlowerCallable:
"""."""

def wrap_ffn(_ffn: FlowerCallable, _layer: Layer) -> FlowerCallable:
def new_ffn(fwd: Fwd) -> Bwd:
return _layer(fwd, _ffn)

return new_ffn

for layer in reversed(layers):
ffn = wrap_ffn(ffn, layer)

return ffn
99 changes: 99 additions & 0 deletions src/py/flwr/client/middleware/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the utility functions."""


import unittest
from typing import List

from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer
from flwr.client.workload_state import WorkloadState
from flwr.proto.task_pb2 import TaskIns, TaskRes

from .utils import make_ffn


def make_mock_middleware(name: str, footprint: List[str]) -> Layer:
"""Make a mock middleware layer."""

def middleware(fwd: Fwd, app: FlowerCallable) -> Bwd:
footprint.append(name)
fwd.task_ins.task_id += f"{name}"
bwd = app(fwd)
footprint.append(name)
bwd.task_res.task_id += f"{name}"
return bwd

return middleware


def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable:
"""Make a mock app."""

def app(fwd: Fwd) -> Bwd:
footprint.append(name)
fwd.task_ins.task_id += f"{name}"
return Bwd(task_res=TaskRes(task_id=name), state=WorkloadState({}))

return app


class TestMakeApp(unittest.TestCase):
"""Tests for the `make_app` function."""

def test_multiple_middlewares(self) -> None:
"""Test if multiple middlewares are called in the correct order."""
# Prepare
footprint: List[str] = []
mock_app = make_mock_app("app", footprint)
mock_middleware_names = [f"middleware{i}" for i in range(1, 15)]
mock_middleware_layers = [
make_mock_middleware(name, footprint) for name in mock_middleware_names
]
task_ins = TaskIns()

# Execute
wrapped_app = make_ffn(mock_app, mock_middleware_layers)
task_res = wrapped_app(Fwd(task_ins=task_ins, state=WorkloadState({}))).task_res

# Assert
trace = mock_middleware_names + ["app"]
self.assertEqual(footprint, trace + list(reversed(mock_middleware_names)))
# pylint: disable-next=no-member
self.assertEqual(task_ins.task_id, "".join(trace))
self.assertEqual(task_res.task_id, "".join(reversed(trace)))

def test_filter(self) -> None:
"""Test if a middleware can filter incoming TaskIns."""
# Prepare
footprint: List[str] = []
mock_app = make_mock_app("app", footprint)
task_ins = TaskIns()

def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd:
footprint.append("filter")
fwd.task_ins.task_id += "filter"
# Skip calling app
return Bwd(task_res=TaskRes(task_id="filter"), state=WorkloadState({}))

# Execute
wrapped_app = make_ffn(mock_app, [filter_layer])
task_res = wrapped_app(Fwd(task_ins=task_ins, state=WorkloadState({}))).task_res

# Assert
self.assertEqual(footprint, ["filter"])
# pylint: disable-next=no-member
self.assertEqual(task_ins.task_id, "filter")
self.assertEqual(task_res.task_id, "filter")
1 change: 1 addition & 0 deletions src/py/flwr/client/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ class Bwd:

FlowerCallable = Callable[[Fwd], Bwd]
ClientFn = Callable[[str], Client]
Layer = Callable[[Fwd, FlowerCallable], Bwd]

0 comments on commit 251fdf1

Please sign in to comment.