Decouple client callable (#2393)
danieljanes authored Nov 23, 2023
1 parent 821d843 commit 1f9fa75
Showing 11 changed files with 495 additions and 18 deletions.
49 changes: 49 additions & 0 deletions examples/mt-pytorch-callable/README.md
@@ -0,0 +1,49 @@
# Deploy 🧪

🧪 = this page covers experimental features that might change in future versions of Flower

This how-to guide describes how to deploy a long-running Flower server and connect long-running Flower clients and a driver script to it.

## Preconditions

Let's assume the following project structure:

```bash
$ tree .
.
├── client.py
├── driver.py
└── requirements.txt
```

## Install dependencies

```bash
pip install -r requirements.txt
```

## Start the long-running Flower server

```bash
flower-server --insecure
```

## Start the long-running Flower client

In a new terminal window, start the first long-running Flower client:

```bash
flower-client --callable client:flower
```

In another new terminal window, start the second long-running Flower client:

```bash
flower-client --callable client:flower
```
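
The `client:flower` argument uses `<module>:<attribute>` notation: `flower-client` imports the `client` module and loads the object named `flower` from it. Below is a minimal sketch of such a callable, assuming a trivial `NumPyClient`; the actual `client.py` in this example trains a CNN on CIFAR-10.

```python
# client.py (minimal sketch) -- the object referenced by `--callable client:flower`
import flwr as fl


class FlowerClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        return []  # model parameters as a list of NumPy ndarrays

    def fit(self, parameters, config):
        return [], 0, {}  # updated parameters, number of examples, metrics

    def evaluate(self, parameters, config):
        return 0.0, 0, {}  # loss, number of examples, metrics


def client_fn(cid: str):
    # Create a new client instance for the given client ID
    return FlowerClient().to_client()


# `flower-client --callable client:flower` loads this object
flower = fl.flower.Flower(client_fn=client_fn)
```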

## Start the Driver script

```bash
python driver.py
```
123 changes: 123 additions & 0 deletions examples/mt-pytorch-callable/client.py
@@ -0,0 +1,123 @@
import warnings
from collections import OrderedDict

import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm


# #############################################################################
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
# #############################################################################

warnings.filterwarnings("ignore", category=UserWarning)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)


def train(net, trainloader, epochs):
"""Train the model on the training set."""
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for _ in range(epochs):
for images, labels in tqdm(trainloader):
optimizer.zero_grad()
criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
optimizer.step()


def test(net, testloader):
"""Validate the model on the test set."""
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
with torch.no_grad():
for images, labels in tqdm(testloader):
outputs = net(images.to(DEVICE))
labels = labels.to(DEVICE)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
accuracy = correct / len(testloader.dataset)
return loss, accuracy


def load_data():
"""Load CIFAR-10 (training and test set)."""
trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = CIFAR10("./data", train=True, download=True, transform=trf)
testset = CIFAR10("./data", train=False, download=True, transform=trf)
return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset)


# #############################################################################
# 2. Federation of the pipeline with Flower
# #############################################################################

# Load model and data (simple CNN, CIFAR-10)
net = Net().to(DEVICE)
trainloader, testloader = load_data()


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]

def set_parameters(self, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)

def fit(self, parameters, config):
self.set_parameters(parameters)
train(net, trainloader, epochs=1)
return self.get_parameters(config={}), len(trainloader.dataset), {}

def evaluate(self, parameters, config):
self.set_parameters(parameters)
loss, accuracy = test(net, testloader)
return loss, len(testloader.dataset), {"accuracy": accuracy}


def client_fn(cid: str):
"""."""
return FlowerClient().to_client()


# To run this: `flower-client --callable client:flower`
flower = fl.flower.Flower(
client_fn=client_fn,
)


if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
server_address="0.0.0.0:9092",
client=FlowerClient().to_client(),
transport="grpc-rere",
)
25 changes: 25 additions & 0 deletions examples/mt-pytorch-callable/driver.py
@@ -0,0 +1,25 @@
from typing import List, Tuple

import flwr as fl
from flwr.common import Metrics


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
# Multiply accuracy of each client by number of examples used
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
examples = [num_examples for num_examples, _ in metrics]

# Aggregate and return custom metric (weighted average)
return {"accuracy": sum(accuracies) / sum(examples)}


# Define strategy
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average)

# Start Flower driver
fl.driver.start_driver(
server_address="0.0.0.0:9091",
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy,
)
16 changes: 16 additions & 0 deletions examples/mt-pytorch-callable/pyproject.toml
@@ -0,0 +1,16 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "quickstart-pytorch"
version = "0.1.0"
description = "PyTorch Federated Learning Quickstart with Flower"
authors = ["The Flower Authors <hello@flower.dev>"]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = { path = "../../", develop = true, extras = ["simulation", "rest"] }
torch = "1.13.1"
torchvision = "0.14.1"
tqdm = "4.65.0"
4 changes: 4 additions & 0 deletions examples/mt-pytorch-callable/requirements.txt
@@ -0,0 +1,4 @@
flwr>=1.0, <2.0
torch==1.13.1
torchvision==0.14.1
tqdm==4.65.0
20 changes: 20 additions & 0 deletions examples/mt-pytorch-callable/run.sh
@@ -0,0 +1,20 @@
#!/bin/bash
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

# Download the CIFAR-10 dataset
python -c "from torchvision.datasets import CIFAR10; CIFAR10('./data', download=True)"

echo "Starting server"
python server.py &
sleep 3 # Sleep for 3s to give the server enough time to start

for i in `seq 0 1`; do
echo "Starting client $i"
python client.py &
done

# Enable CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait
25 changes: 25 additions & 0 deletions examples/mt-pytorch-callable/server.py
@@ -0,0 +1,25 @@
from typing import List, Tuple

import flwr as fl
from flwr.common import Metrics


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
# Multiply accuracy of each client by number of examples used
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
examples = [num_examples for num_examples, _ in metrics]

# Aggregate and return custom metric (weighted average)
return {"accuracy": sum(accuracies) / sum(examples)}


# Define strategy
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average)

# Start Flower server
fl.server.start_server(
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy,
)
3 changes: 2 additions & 1 deletion src/py/flwr/__init__.py
@@ -17,12 +17,13 @@

from flwr.common.version import package_version as _package_version

from . import client, common, driver, server, simulation
from . import client, common, driver, flower, server, simulation

__all__ = [
"client",
"common",
"driver",
"flower",
"server",
"simulation",
]
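
With `flower` exported from the top-level package, the decoupled entry point is available as `fl.flower.Flower`. Below is a condensed sketch contrasting it with the pre-existing `start_client` path; it assumes it sits next to the example's `client.py` (which defines `FlowerClient` and `client_fn`).

```python
# Sketch: the two ways to run a client after this change.
import flwr as fl
from client import FlowerClient, client_fn  # from examples/mt-pytorch-callable/client.py

# Decoupled path: expose a callable that the long-running `flower-client`
# process loads via `--callable <module>:<attribute>`.
flower = fl.flower.Flower(client_fn=client_fn)

# Pre-existing path: start the client directly from Python.
if __name__ == "__main__":
    fl.client.start_client(
        server_address="0.0.0.0:9092",
        client=FlowerClient().to_client(),
        transport="grpc-rere",
    )
```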