diff --git a/examples/mt-pytorch-callable/README.md b/examples/mt-pytorch-callable/README.md new file mode 100644 index 000000000000..65ef000c26f2 --- /dev/null +++ b/examples/mt-pytorch-callable/README.md @@ -0,0 +1,49 @@ +# Deploy ๐Ÿงช + +๐Ÿงช = this page covers experimental features that might change in future versions of Flower + +This how-to guide describes the deployment of a long-running Flower server. + +## Preconditions + +Let's assume the following project structure: + +```bash +$ tree . +. +โ””โ”€โ”€ client.py +โ”œโ”€โ”€ driver.py +โ”œโ”€โ”€ requirements.txt +``` + +## Install dependencies + +```bash +pip install -r requirements.txt +``` + +## Start the long-running Flower server + +```bash +flower-server --insecure +``` + +## Start the long-running Flower client + +In a new terminal window, start the first long-running Flower client: + +```bash +flower-client --callable client:flower +``` + +In yet another new terminal window, start the second long-running Flower client: + +```bash +flower-client --callable client:flower +``` + +## Start the Driver script + +```bash +python driver.py +``` diff --git a/examples/mt-pytorch-callable/client.py b/examples/mt-pytorch-callable/client.py new file mode 100644 index 000000000000..6f9747784ae0 --- /dev/null +++ b/examples/mt-pytorch-callable/client.py @@ -0,0 +1,123 @@ +import warnings +from collections import OrderedDict + +import flwr as fl +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, Normalize, ToTensor +from tqdm import tqdm + + +# ############################################################################# +# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader +# ############################################################################# + +warnings.filterwarnings("ignore", category=UserWarning) +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class Net(nn.Module): + """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" + + def __init__(self) -> None: + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + return self.fc3(x) + + +def train(net, trainloader, epochs): + """Train the model on the training set.""" + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + for _ in range(epochs): + for images, labels in tqdm(trainloader): + optimizer.zero_grad() + criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward() + optimizer.step() + + +def test(net, testloader): + """Validate the model on the test set.""" + criterion = torch.nn.CrossEntropyLoss() + correct, loss = 0, 0.0 + with torch.no_grad(): + for images, labels in tqdm(testloader): + outputs = net(images.to(DEVICE)) + labels = labels.to(DEVICE) + loss += criterion(outputs, labels).item() + correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() + accuracy = correct / len(testloader.dataset) + return loss, accuracy + + +def load_data(): + """Load CIFAR-10 (training and test set).""" + trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + trainset = CIFAR10("./data", train=True, download=True, transform=trf) + testset = CIFAR10("./data", train=False, download=True, transform=trf) + return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset) + + +# ############################################################################# +# 2. Federation of the pipeline with Flower +# ############################################################################# + +# Load model and data (simple CNN, CIFAR-10) +net = Net().to(DEVICE) +trainloader, testloader = load_data() + + +# Define Flower client +class FlowerClient(fl.client.NumPyClient): + def get_parameters(self, config): + return [val.cpu().numpy() for _, val in net.state_dict().items()] + + def set_parameters(self, parameters): + params_dict = zip(net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) + + def fit(self, parameters, config): + self.set_parameters(parameters) + train(net, trainloader, epochs=1) + return self.get_parameters(config={}), len(trainloader.dataset), {} + + def evaluate(self, parameters, config): + self.set_parameters(parameters) + loss, accuracy = test(net, testloader) + return loss, len(testloader.dataset), {"accuracy": accuracy} + + +def client_fn(cid: str): + """.""" + return FlowerClient().to_client() + + +# To run this: `flower-client --callable client:flower` +flower = fl.flower.Flower( + client_fn=client_fn, +) + + +if __name__ == "__main__": + # Start Flower client + fl.client.start_client( + server_address="0.0.0.0:9092", + client=FlowerClient().to_client(), + transport="grpc-rere", + ) diff --git a/examples/mt-pytorch-callable/driver.py b/examples/mt-pytorch-callable/driver.py new file mode 100644 index 000000000000..1248672b6813 --- /dev/null +++ b/examples/mt-pytorch-callable/driver.py @@ -0,0 +1,25 @@ +from typing import List, Tuple + +import flwr as fl +from flwr.common import Metrics + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + # Multiply accuracy of each client by number of examples used + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] + + # Aggregate and return custom metric (weighted average) + return {"accuracy": sum(accuracies) / sum(examples)} + + +# Define strategy +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) + +# Start Flower driver +fl.driver.start_driver( + server_address="0.0.0.0:9091", + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, +) diff --git a/examples/mt-pytorch-callable/pyproject.toml b/examples/mt-pytorch-callable/pyproject.toml new file mode 100644 index 000000000000..0d1a91836006 --- /dev/null +++ b/examples/mt-pytorch-callable/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "quickstart-pytorch" +version = "0.1.0" +description = "PyTorch Federated Learning Quickstart with Flower" +authors = ["The Flower Authors "] + +[tool.poetry.dependencies] +python = ">=3.8,<3.11" +flwr = { path = "../../", develop = true, extras = ["simulation", "rest"] } +torch = "1.13.1" +torchvision = "0.14.1" +tqdm = "4.65.0" diff --git a/examples/mt-pytorch-callable/requirements.txt b/examples/mt-pytorch-callable/requirements.txt new file mode 100644 index 000000000000..797ca6db6244 --- /dev/null +++ b/examples/mt-pytorch-callable/requirements.txt @@ -0,0 +1,4 @@ +flwr>=1.0, <2.0 +torch==1.13.1 +torchvision==0.14.1 +tqdm==4.65.0 diff --git a/examples/mt-pytorch-callable/run.sh b/examples/mt-pytorch-callable/run.sh new file mode 100755 index 000000000000..d2bf34f834b1 --- /dev/null +++ b/examples/mt-pytorch-callable/run.sh @@ -0,0 +1,20 @@ +#!/bin/bash +set -e +cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ + +# Download the CIFAR-10 dataset +python -c "from torchvision.datasets import CIFAR10; CIFAR10('./data', download=True)" + +echo "Starting server" +python server.py & +sleep 3 # Sleep for 3s to give the server enough time to start + +for i in `seq 0 1`; do + echo "Starting client $i" + python client.py & +done + +# Enable CTRL+C to stop all background processes +trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM +# Wait for all background processes to complete +wait diff --git a/examples/mt-pytorch-callable/server.py b/examples/mt-pytorch-callable/server.py new file mode 100644 index 000000000000..fe691a88aba0 --- /dev/null +++ b/examples/mt-pytorch-callable/server.py @@ -0,0 +1,25 @@ +from typing import List, Tuple + +import flwr as fl +from flwr.common import Metrics + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + # Multiply accuracy of each client by number of examples used + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] + + # Aggregate and return custom metric (weighted average) + return {"accuracy": sum(accuracies) / sum(examples)} + + +# Define strategy +strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) + +# Start Flower server +fl.server.start_server( + server_address="0.0.0.0:8080", + config=fl.server.ServerConfig(num_rounds=3), + strategy=strategy, +) diff --git a/src/py/flwr/__init__.py b/src/py/flwr/__init__.py index d3cbf00747a4..e05799280339 100644 --- a/src/py/flwr/__init__.py +++ b/src/py/flwr/__init__.py @@ -17,12 +17,13 @@ from flwr.common.version import package_version as _package_version -from . import client, common, driver, server, simulation +from . import client, common, driver, flower, server, simulation __all__ = [ "client", "common", "driver", + "flower", "server", "simulation", ] diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 0013b74c631c..b39dbbfc33c0 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -22,6 +22,7 @@ from typing import Callable, ContextManager, Optional, Tuple, Union from flwr.client.client import Client +from flwr.client.flower import Bwd, Flower, Fwd from flwr.client.typing import ClientFn from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event from flwr.common.address import parse_address @@ -32,13 +33,15 @@ TRANSPORT_TYPE_REST, TRANSPORT_TYPES, ) -from flwr.common.logger import log +from flwr.common.logger import log, warn_experimental_feature from flwr.proto.task_pb2 import TaskIns, TaskRes +from .flower import load_callable from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response -from .message_handler.message_handler import handle, handle_control_message +from .message_handler.message_handler import handle_control_message from .numpy_client import NumPyClient +from .workload_state import WorkloadState def run_client() -> None: @@ -48,6 +51,22 @@ def run_client() -> None: args = _parse_args_client().parse_args() print(args.server) + print(args.callable_dir) + print(args.callable) + + callable_dir = args.callable_dir + if callable_dir is not None: + sys.path.insert(0, callable_dir) + + def _load() -> Flower: + flower: Flower = load_callable(args.callable) + return flower + + return start_client( + server_address=args.server, + load_callable_fn=_load, + transport="grpc-rere", # Only + ) def _parse_args_client() -> argparse.ArgumentParser: @@ -58,8 +77,18 @@ def _parse_args_client() -> argparse.ArgumentParser: parser.add_argument( "--server", - help="Server address", default="0.0.0.0:9092", + help="Server address", + ) + parser.add_argument( + "--callable", + help="For example: `client:flower` or `project.package.module:wrapper.flower`", + ) + parser.add_argument( + "--callable-dir", + default="", + help="Add specified directory to the PYTHONPATH and load callable from there." + " Default: current working directory.", ) return parser @@ -84,6 +113,7 @@ def _check_actionable_client( def start_client( *, server_address: str, + load_callable_fn: Optional[Callable[[], Flower]] = None, client_fn: Optional[ClientFn] = None, client: Optional[Client] = None, grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, @@ -98,6 +128,8 @@ def start_client( The IPv4 or IPv6 address of the server. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"[::]:8080"`. + load_callable_fn : Optional[Callable[[], Flower]] (default: None) + ... client_fn : Optional[ClientFn] A callable that instantiates a Client. (default: None) client : Optional[flwr.client.Client] @@ -146,20 +178,31 @@ class `flwr.client.Client` (default: None) """ event(EventType.START_CLIENT_ENTER) - _check_actionable_client(client, client_fn) + if load_callable_fn is None: + _check_actionable_client(client, client_fn) - if client_fn is None: - # Wrap `Client` instance in `client_fn` - def single_client_factory( - cid: str, # pylint: disable=unused-argument - ) -> Client: - if client is None: # Added this to keep mypy happy - raise Exception( - "Both `client_fn` and `client` are `None`, but one is required" - ) - return client # Always return the same instance + if client_fn is None: + # Wrap `Client` instance in `client_fn` + def single_client_factory( + cid: str, # pylint: disable=unused-argument + ) -> Client: + if client is None: # Added this to keep mypy happy + raise Exception( + "Both `client_fn` and `client` are `None`, but one is required" + ) + return client # Always return the same instance + + client_fn = single_client_factory + + def _load_app() -> Flower: + return Flower(client_fn=client_fn) - client_fn = single_client_factory + load_callable_fn = _load_app + else: + warn_experimental_feature("`load_callable_fn`") + + # At this point, only `load_callable_fn` should be used + # Both `client` and `client_fn` must not be used directly # Initialize connection context manager connection, address = _init_connection(transport, server_address) @@ -190,11 +233,18 @@ def single_client_factory( send(task_res) break + # Load app + app: Flower = load_callable_fn() + # Handle task message - task_res = handle(client_fn, task_ins) + fwd_msg: Fwd = Fwd( + task_ins=task_ins, + state=WorkloadState(state={}), + ) + bwd_msg: Bwd = app(fwd=fwd_msg) # Send - send(task_res) + send(bwd_msg.task_res) # Unregister node if delete_node is not None: diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py new file mode 100644 index 000000000000..9eeb41887e24 --- /dev/null +++ b/src/py/flwr/client/flower.py @@ -0,0 +1,138 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower callable.""" + + +import importlib +from dataclasses import dataclass +from typing import Callable, cast + +from flwr.client.message_handler.message_handler import handle +from flwr.client.typing import ClientFn +from flwr.client.workload_state import WorkloadState +from flwr.proto.task_pb2 import TaskIns, TaskRes + + +@dataclass +class Fwd: + """.""" + + task_ins: TaskIns + state: WorkloadState + + +@dataclass +class Bwd: + """.""" + + task_res: TaskRes + state: WorkloadState + + +FlowerCallable = Callable[[Fwd], Bwd] + + +class Flower: + """Flower callable. + + Examples + -------- + Assuming a typical client implementation in `FlowerClient`, you can wrap it in a + Flower callable as follows: + + >>> class FlowerClient(NumPyClient): + >>> # ... + >>> + >>> def client_fn(cid): + >>> return FlowerClient().to_client() + >>> + >>> flower = Flower(client_fn) + + If the above code is in a Python module called `client`, it can be started as + follows: + + >>> flower-client --callable client:flower + + In this `client:flower` example, `client` refers to the Python module in which the + previous code lives in. `flower` refers to the global attribute `flower` that points + to an object of type `Flower` (a Flower callable). + """ + + def __init__( + self, + client_fn: ClientFn, # Only for backward compatibility + ) -> None: + self.client_fn = client_fn + + def __call__(self, fwd: Fwd) -> Bwd: + """.""" + # Execute the task + task_res = handle( + client_fn=self.client_fn, + task_ins=fwd.task_ins, + ) + return Bwd( + task_res=task_res, + state=WorkloadState(state={}), + ) + + +class LoadCallableError(Exception): + """.""" + + +def load_callable(module_attribute_str: str) -> Flower: + """Load the `Flower` object specified in a module attribute string. + + The module/attribute string should have the form :. Valid + examples include `client:flower` and `project.package.module:wrapper.flower`. It + must refer to a module on the PYTHONPATH, the module needs to have the specified + attribute, and the attribute must be of type `Flower`. + """ + module_str, _, attributes_str = module_attribute_str.partition(":") + if not module_str: + raise LoadCallableError( + f"Missing module in {module_attribute_str}", + ) from None + if not attributes_str: + raise LoadCallableError( + f"Missing attribute in {module_attribute_str}", + ) from None + + # Load module + try: + module = importlib.import_module(module_str) + except ModuleNotFoundError: + raise LoadCallableError( + f"Unable to load module {module_str}", + ) from None + + # Recursively load attribute + attribute = module + try: + for attribute_str in attributes_str.split("."): + attribute = getattr(attribute, attribute_str) + except AttributeError: + raise LoadCallableError( + f"Unable to load attribute {attributes_str} from module {module_str}", + ) from None + + # Check type + if not isinstance(attribute, Flower): + raise LoadCallableError( + f"Attribute {attributes_str} is not of type {Flower}", + ) from None + + return cast(Flower, attribute) diff --git a/src/py/flwr/flower/__init__.py b/src/py/flwr/flower/__init__.py new file mode 100644 index 000000000000..090c78062d02 --- /dev/null +++ b/src/py/flwr/flower/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2020 Adap GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower callable package.""" + + +from flwr.client.flower import Bwd as Bwd +from flwr.client.flower import Flower as Flower +from flwr.client.flower import Fwd as Fwd + +__all__ = [ + "Flower", + "Fwd", + "Bwd", +]