From 2df41162b6092250c21c20f58fcdbb546c1baffa Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Fri, 1 Dec 2023 16:59:08 +0000 Subject: [PATCH] rename make_app to make_fc --- src/py/flwr/client/flower.py | 23 ++++++++------------- src/py/flwr/client/middleware/__init__.py | 4 ++-- src/py/flwr/client/middleware/utils.py | 2 +- src/py/flwr/client/middleware/utils_test.py | 6 +++--- 4 files changed, 15 insertions(+), 20 deletions(-) diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py index e82c8b0e75e4..385f3a487351 100644 --- a/src/py/flwr/client/flower.py +++ b/src/py/flwr/client/flower.py @@ -19,7 +19,7 @@ from typing import List, Optional, cast from flwr.client.message_handler.message_handler import handle -from flwr.client.middleware.utils import make_app +from flwr.client.middleware.utils import make_fc from flwr.client.typing import Bwd, ClientFn, Fwd, Layer @@ -54,26 +54,21 @@ def __init__( client_fn: ClientFn, # Only for backward compatibility middleware: Optional[List[Layer]] = None, ) -> None: - self.client_fn = client_fn - self.mw_list = middleware if middleware is not None else [] - - def __call__(self, fwd: Fwd) -> Bwd: - """.""" - # Create wrapper function for `handle` - def handle_app(_fwd: Fwd) -> Bwd: + def fn(fwd: Fwd) -> Bwd: task_res, state_updated = handle( - client_fn=self.client_fn, - state=_fwd.state, - task_ins=_fwd.task_ins, + client_fn=client_fn, + state=fwd.state, + task_ins=fwd.task_ins, ) return Bwd(task_res=task_res, state=state_updated) # Wrap middleware layers around handle_app - app = make_app(handle_app, self.mw_list) + self._call = make_fc(fn, middleware if middleware is not None else []) - # Execute the task - return app(fwd) + def __call__(self, fwd: Fwd) -> Bwd: + """.""" + return self._call(fwd) class LoadCallableError(Exception): diff --git a/src/py/flwr/client/middleware/__init__.py b/src/py/flwr/client/middleware/__init__.py index 5b474edf1f5b..af280faeaf06 100644 --- a/src/py/flwr/client/middleware/__init__.py +++ b/src/py/flwr/client/middleware/__init__.py @@ -15,8 +15,8 @@ """Middleware layers.""" -from .utils import make_app +from .utils import make_fc __all__ = [ - "make_app", + "make_fc", ] diff --git a/src/py/flwr/client/middleware/utils.py b/src/py/flwr/client/middleware/utils.py index 5e635d49d24d..697b22fabc2c 100644 --- a/src/py/flwr/client/middleware/utils.py +++ b/src/py/flwr/client/middleware/utils.py @@ -19,7 +19,7 @@ from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer -def make_app(app: FlowerCallable, middleware_layers: List[Layer]) -> FlowerCallable: +def make_fc(app: FlowerCallable, middleware_layers: List[Layer]) -> FlowerCallable: """.""" def wrap_app(_app: FlowerCallable, _layer: Layer) -> FlowerCallable: diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py index c9fb46487952..678cf42b00c8 100644 --- a/src/py/flwr/client/middleware/utils_test.py +++ b/src/py/flwr/client/middleware/utils_test.py @@ -22,7 +22,7 @@ from flwr.client.workload_state import WorkloadState from flwr.proto.task_pb2 import TaskIns, TaskRes -from .utils import make_app +from .utils import make_fc def make_mock_middleware(name: str, footprint: List[str]) -> Layer: @@ -65,7 +65,7 @@ def test_multiple_middlewares(self) -> None: task_ins = TaskIns() # Execute - wrapped_app = make_app(mock_app, mock_middleware_layers) + wrapped_app = make_fc(mock_app, mock_middleware_layers) task_res = wrapped_app(Fwd(task_ins=task_ins, state=WorkloadState({}))).task_res # Assert @@ -89,7 +89,7 @@ def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd: return Bwd(task_res=TaskRes(task_id="filter"), state=WorkloadState({})) # Execute - wrapped_app = make_app(mock_app, [filter_layer]) + wrapped_app = make_fc(mock_app, [filter_layer]) task_res = wrapped_app(Fwd(task_ins=task_ins, state=WorkloadState({}))).task_res # Assert