diff --git a/doc/source/how-to-use-built-in-middleware-layers.rst b/doc/source/how-to-use-built-in-middleware-layers.rst new file mode 100644 index 000000000000..2e91623b26be --- /dev/null +++ b/doc/source/how-to-use-built-in-middleware-layers.rst @@ -0,0 +1,87 @@ +Use Built-in Middleware Layers +============================== + +**Note: This tutorial covers experimental features. The functionality and interfaces may change in future versions.** + +In this tutorial, we will learn how to utilize built-in middleware layers to augment the behavior of a ``FlowerCallable``. Middleware allows us to perform operations before and after a task is processed in the ``FlowerCallable``. + +What is middleware? +------------------- + +Middleware is a callable that wraps around a ``FlowerCallable``. It can manipulate or inspect incoming tasks (``TaskIns``) in the ``Fwd`` and the resulting tasks (``TaskRes``) in the ``Bwd``. The signature for a middleware layer (``Layer``) is as follows: + +.. code-block:: python + + FlowerCallable = Callable[[Fwd], Bwd] + Layer = Callable[[Fwd, FlowerCallable], Bwd] + +A typical middleware function might look something like this: + +.. code-block:: python + + def example_middleware(fwd: Fwd, ffn: FlowerCallable) -> Bwd: + # Do something with Fwd before passing to the inner ``FlowerCallable``. + bwd = ffn(fwd) + # Do something with Bwd before returning. + return bwd + +Using middleware layers +----------------------- + +To use middleware layers in your ``FlowerCallable``, you can follow these steps: + +1. Import the required middleware +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +First, import the built-in middleware layers you intend to use: + +.. code-block:: python + + import flwr as fl + from flwr.client.middleware import example_middleware1, example_middleware2 + +2. Define your client function +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Define your client function (``client_fn``) that will be wrapped by the middleware: + +.. code-block:: python + + def client_fn(cid): + # Your client code goes here. + return # your client + +3. Create the ``FlowerCallable`` with middleware +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Create your ``FlowerCallable`` and pass the middleware layers as a list to the ``layers`` argument. The order in which you provide the middleware layers matters: + +.. code-block:: python + + flower = fl.app.Flower( + client_fn=client_fn, + layers=[ + example_middleware1, # Middleware layer 1 + example_middleware2, # Middleware layer 2 + ] + ) + +Order of execution +------------------ + +When the ``FlowerCallable`` runs, the middleware layers are executed in the order they are provided in the list: + +1. ``example_middleware1`` (outermost layer) +2. ``example_middleware2`` (next layer) +3. Message handler (core function that handles ``TaskIns`` and returns ``TaskRes``) +4. ``example_middleware2`` (on the way back) +5. ``example_middleware1`` (outermost layer on the way back) + +Each middleware has a chance to inspect and modify the ``TaskIns`` in the ``Fwd`` before passing it to the next layer, and likewise with the ``TaskRes`` in the ``Bwd`` before returning it up the stack. + +Conclusion +---------- + +By following this guide, you have learned how to effectively use middleware layers to enhance your ``FlowerCallable``'s functionality. Remember that the order of middleware is crucial and affects how the input and output are processed. + +Enjoy building more robust and flexible ``FlowerCallable``s with middleware layers! diff --git a/doc/source/index.rst b/doc/source/index.rst index 9b4b5b195fbb..f7a4ec3daeda 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -91,6 +91,7 @@ Problem-oriented how-to guides show step-by-step how to achieve a specific goal. how-to-configure-logging how-to-enable-ssl-connections how-to-upgrade-to-flower-1.0 + how-to-use-built-in-middleware-layers .. toctree:: :maxdepth: 1 diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py index 10c78ec45b44..535f096e5866 100644 --- a/src/py/flwr/client/flower.py +++ b/src/py/flwr/client/flower.py @@ -16,10 +16,11 @@ import importlib -from typing import cast +from typing import List, Optional, cast from flwr.client.message_handler.message_handler import handle -from flwr.client.typing import Bwd, ClientFn, Fwd +from flwr.client.middleware.utils import make_ffn +from flwr.client.typing import Bwd, ClientFn, Fwd, Layer class Flower: @@ -51,21 +52,23 @@ class Flower: def __init__( self, client_fn: ClientFn, # Only for backward compatibility + layers: Optional[List[Layer]] = None, ) -> None: - self.client_fn = client_fn + # Create wrapper function for `handle` + def ffn(fwd: Fwd) -> Bwd: # pylint: disable=invalid-name + task_res, state_updated = handle( + client_fn=client_fn, + state=fwd.state, + task_ins=fwd.task_ins, + ) + return Bwd(task_res=task_res, state=state_updated) + + # Wrap middleware layers around the wrapped handle function + self._call = make_ffn(ffn, layers if layers is not None else []) def __call__(self, fwd: Fwd) -> Bwd: """.""" - # Execute the task - task_res, state_updated = handle( - client_fn=self.client_fn, - state=fwd.state, - task_ins=fwd.task_ins, - ) - return Bwd( - task_res=task_res, - state=state_updated, - ) + return self._call(fwd) class LoadCallableError(Exception): diff --git a/src/py/flwr/client/middleware/__init__.py b/src/py/flwr/client/middleware/__init__.py new file mode 100644 index 000000000000..58b31296fbbe --- /dev/null +++ b/src/py/flwr/client/middleware/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Middleware layers.""" + + +from .utils import make_ffn + +__all__ = [ + "make_ffn", +] diff --git a/src/py/flwr/client/middleware/utils.py b/src/py/flwr/client/middleware/utils.py new file mode 100644 index 000000000000..d93132403c1e --- /dev/null +++ b/src/py/flwr/client/middleware/utils.py @@ -0,0 +1,35 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions for middleware layers.""" + + +from typing import List + +from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer + + +def make_ffn(ffn: FlowerCallable, layers: List[Layer]) -> FlowerCallable: + """.""" + + def wrap_ffn(_ffn: FlowerCallable, _layer: Layer) -> FlowerCallable: + def new_ffn(fwd: Fwd) -> Bwd: + return _layer(fwd, _ffn) + + return new_ffn + + for layer in reversed(layers): + ffn = wrap_ffn(ffn, layer) + + return ffn diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py new file mode 100644 index 000000000000..9a2d888a5ecd --- /dev/null +++ b/src/py/flwr/client/middleware/utils_test.py @@ -0,0 +1,99 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the utility functions.""" + + +import unittest +from typing import List + +from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer +from flwr.client.workload_state import WorkloadState +from flwr.proto.task_pb2 import TaskIns, TaskRes + +from .utils import make_ffn + + +def make_mock_middleware(name: str, footprint: List[str]) -> Layer: + """Make a mock middleware layer.""" + + def middleware(fwd: Fwd, app: FlowerCallable) -> Bwd: + footprint.append(name) + fwd.task_ins.task_id += f"{name}" + bwd = app(fwd) + footprint.append(name) + bwd.task_res.task_id += f"{name}" + return bwd + + return middleware + + +def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable: + """Make a mock app.""" + + def app(fwd: Fwd) -> Bwd: + footprint.append(name) + fwd.task_ins.task_id += f"{name}" + return Bwd(task_res=TaskRes(task_id=name), state=WorkloadState({})) + + return app + + +class TestMakeApp(unittest.TestCase): + """Tests for the `make_app` function.""" + + def test_multiple_middlewares(self) -> None: + """Test if multiple middlewares are called in the correct order.""" + # Prepare + footprint: List[str] = [] + mock_app = make_mock_app("app", footprint) + mock_middleware_names = [f"middleware{i}" for i in range(1, 15)] + mock_middleware_layers = [ + make_mock_middleware(name, footprint) for name in mock_middleware_names + ] + task_ins = TaskIns() + + # Execute + wrapped_app = make_ffn(mock_app, mock_middleware_layers) + task_res = wrapped_app(Fwd(task_ins=task_ins, state=WorkloadState({}))).task_res + + # Assert + trace = mock_middleware_names + ["app"] + self.assertEqual(footprint, trace + list(reversed(mock_middleware_names))) + # pylint: disable-next=no-member + self.assertEqual(task_ins.task_id, "".join(trace)) + self.assertEqual(task_res.task_id, "".join(reversed(trace))) + + def test_filter(self) -> None: + """Test if a middleware can filter incoming TaskIns.""" + # Prepare + footprint: List[str] = [] + mock_app = make_mock_app("app", footprint) + task_ins = TaskIns() + + def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd: + footprint.append("filter") + fwd.task_ins.task_id += "filter" + # Skip calling app + return Bwd(task_res=TaskRes(task_id="filter"), state=WorkloadState({})) + + # Execute + wrapped_app = make_ffn(mock_app, [filter_layer]) + task_res = wrapped_app(Fwd(task_ins=task_ins, state=WorkloadState({}))).task_res + + # Assert + self.assertEqual(footprint, ["filter"]) + # pylint: disable-next=no-member + self.assertEqual(task_ins.task_id, "filter") + self.assertEqual(task_res.task_id, "filter") diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index 2c1f7506592c..2dd368bf6d08 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -41,3 +41,4 @@ class Bwd: FlowerCallable = Callable[[Fwd], Bwd] ClientFn = Callable[[str], Client] +Layer = Callable[[Fwd, FlowerCallable], Bwd]