Skip to content

Commit

Permalink
use FlowerCallable instead of App
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Dec 1, 2023
1 parent 9a825ba commit d4a9f7d
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 37 deletions.
3 changes: 1 addition & 2 deletions src/py/flwr/client/flower.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
from typing import List, Optional, cast

from flwr.client.message_handler.message_handler import handle
from flwr.client.middleware.typing import Layer
from flwr.client.middleware.utils import make_app
from flwr.client.typing import Bwd, ClientFn, Fwd
from flwr.client.typing import Bwd, ClientFn, Fwd, Layer


class Flower:
Expand Down
3 changes: 0 additions & 3 deletions src/py/flwr/client/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@
"""Middleware layers."""


from .typing import App, Layer
from .utils import make_app

__all__ = [
"App",
"Layer",
"make_app",
]
22 changes: 0 additions & 22 deletions src/py/flwr/client/middleware/typing.py

This file was deleted.

8 changes: 3 additions & 5 deletions src/py/flwr/client/middleware/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@

from typing import List

from flwr.client.typing import Bwd, Fwd
from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer

from .typing import App, Layer


def make_app(app: App, middleware_layers: List[Layer]) -> App:
def make_app(app: FlowerCallable, middleware_layers: List[Layer]) -> FlowerCallable:
"""."""

def wrap_app(_app: App, _layer: Layer) -> App:
def wrap_app(_app: FlowerCallable, _layer: Layer) -> FlowerCallable:
def new_app(fwd: Fwd) -> Bwd:
return _layer(fwd, _app)

Expand Down
9 changes: 4 additions & 5 deletions src/py/flwr/client/middleware/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,17 @@
import unittest
from typing import List

from flwr.client.typing import Bwd, Fwd
from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer
from flwr.client.workload_state import WorkloadState
from flwr.proto.task_pb2 import TaskIns, TaskRes

from .typing import App, Layer
from .utils import make_app


def make_mock_middleware(name: str, footprint: List[str]) -> Layer:
"""Make a mock middleware layer."""

def middleware(fwd: Fwd, app: App) -> Bwd:
def middleware(fwd: Fwd, app: FlowerCallable) -> Bwd:
footprint.append(name)
fwd.task_ins.task_id += f"{name}"
bwd = app(fwd)
Expand All @@ -40,7 +39,7 @@ def middleware(fwd: Fwd, app: App) -> Bwd:
return middleware


def make_mock_app(name: str, footprint: List[str]) -> App:
def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable:
"""Make a mock app."""

def app(fwd: Fwd) -> Bwd:
Expand Down Expand Up @@ -83,7 +82,7 @@ def test_filter(self) -> None:
mock_app = make_mock_app("app", footprint)
task_ins = TaskIns()

def filter_layer(fwd: Fwd, _: App) -> Bwd:
def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd:
footprint.append("filter")
fwd.task_ins.task_id += "filter"
# Skip calling app
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/client/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ class Bwd:

FlowerCallable = Callable[[Fwd], Bwd]
ClientFn = Callable[[str], Client]
Layer = Callable[[Fwd, FlowerCallable], Bwd]

0 comments on commit d4a9f7d

Please sign in to comment.