diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 81bbee148c95..7ce7d51d3d4b 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -23,8 +23,8 @@ from typing import Callable, ContextManager, Optional, Tuple, Union from flwr.client.client import Client -from flwr.client.flower import Bwd, Flower, Fwd -from flwr.client.typing import ClientFn +from flwr.client.flower import Flower +from flwr.client.typing import Bwd, ClientFn, Fwd from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event from flwr.common.address import parse_address from flwr.common.constant import ( diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py index f0d4ce122524..5b083ee11b9f 100644 --- a/src/py/flwr/client/flower.py +++ b/src/py/flwr/client/flower.py @@ -16,32 +16,10 @@ import importlib -from dataclasses import dataclass -from typing import Callable, cast +from typing import cast from flwr.client.message_handler.message_handler import handle -from flwr.client.typing import ClientFn -from flwr.client.workload_state import WorkloadState -from flwr.proto.task_pb2 import TaskIns, TaskRes - - -@dataclass -class Fwd: - """.""" - - task_ins: TaskIns - state: WorkloadState - - -@dataclass -class Bwd: - """.""" - - task_res: TaskRes - state: WorkloadState - - -FlowerCallable = Callable[[Fwd], Bwd] +from flwr.client.typing import Bwd, ClientFn, Fwd class Flower: diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index 7ee6f069768c..2c1f7506592c 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -14,8 +14,30 @@ # ============================================================================== """Custom types for Flower clients.""" +from dataclasses import dataclass from typing import Callable +from flwr.client.workload_state import WorkloadState +from flwr.proto.task_pb2 import TaskIns, TaskRes + from .client import Client as Client + +@dataclass +class Fwd: + """.""" + + task_ins: TaskIns + state: WorkloadState + + +@dataclass +class Bwd: + """.""" + + task_res: TaskRes + state: WorkloadState + + +FlowerCallable = Callable[[Fwd], Bwd] ClientFn = Callable[[str], Client] diff --git a/src/py/flwr/flower/__init__.py b/src/py/flwr/flower/__init__.py index 090c78062d02..892a7ce5afdc 100644 --- a/src/py/flwr/flower/__init__.py +++ b/src/py/flwr/flower/__init__.py @@ -15,9 +15,9 @@ """Flower callable package.""" -from flwr.client.flower import Bwd as Bwd from flwr.client.flower import Flower as Flower -from flwr.client.flower import Fwd as Fwd +from flwr.client.typing import Bwd as Bwd +from flwr.client.typing import Fwd as Fwd __all__ = [ "Flower",