-
Notifications
You must be signed in to change notification settings - Fork 910
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
821d843
commit 1f9fa75
Showing
11 changed files
with
495 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Deploy 🧪 | ||
|
||
🧪 = this page covers experimental features that might change in future versions of Flower | ||
|
||
This how-to guide describes the deployment of a long-running Flower server. | ||
|
||
## Preconditions | ||
|
||
Let's assume the following project structure: | ||
|
||
```bash | ||
$ tree . | ||
. | ||
└── client.py | ||
├── driver.py | ||
├── requirements.txt | ||
``` | ||
|
||
## Install dependencies | ||
|
||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Start the long-running Flower server | ||
|
||
```bash | ||
flower-server --insecure | ||
``` | ||
|
||
## Start the long-running Flower client | ||
|
||
In a new terminal window, start the first long-running Flower client: | ||
|
||
```bash | ||
flower-client --callable client:flower | ||
``` | ||
|
||
In yet another new terminal window, start the second long-running Flower client: | ||
|
||
```bash | ||
flower-client --callable client:flower | ||
``` | ||
|
||
## Start the Driver script | ||
|
||
```bash | ||
python driver.py | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import warnings | ||
from collections import OrderedDict | ||
|
||
import flwr as fl | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.utils.data import DataLoader | ||
from torchvision.datasets import CIFAR10 | ||
from torchvision.transforms import Compose, Normalize, ToTensor | ||
from tqdm import tqdm | ||
|
||
|
||
# ############################################################################# | ||
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader | ||
# ############################################################################# | ||
|
||
warnings.filterwarnings("ignore", category=UserWarning) | ||
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
|
||
class Net(nn.Module): | ||
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" | ||
|
||
def __init__(self) -> None: | ||
super(Net, self).__init__() | ||
self.conv1 = nn.Conv2d(3, 6, 5) | ||
self.pool = nn.MaxPool2d(2, 2) | ||
self.conv2 = nn.Conv2d(6, 16, 5) | ||
self.fc1 = nn.Linear(16 * 5 * 5, 120) | ||
self.fc2 = nn.Linear(120, 84) | ||
self.fc3 = nn.Linear(84, 10) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = self.pool(F.relu(self.conv1(x))) | ||
x = self.pool(F.relu(self.conv2(x))) | ||
x = x.view(-1, 16 * 5 * 5) | ||
x = F.relu(self.fc1(x)) | ||
x = F.relu(self.fc2(x)) | ||
return self.fc3(x) | ||
|
||
|
||
def train(net, trainloader, epochs): | ||
"""Train the model on the training set.""" | ||
criterion = torch.nn.CrossEntropyLoss() | ||
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) | ||
for _ in range(epochs): | ||
for images, labels in tqdm(trainloader): | ||
optimizer.zero_grad() | ||
criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward() | ||
optimizer.step() | ||
|
||
|
||
def test(net, testloader): | ||
"""Validate the model on the test set.""" | ||
criterion = torch.nn.CrossEntropyLoss() | ||
correct, loss = 0, 0.0 | ||
with torch.no_grad(): | ||
for images, labels in tqdm(testloader): | ||
outputs = net(images.to(DEVICE)) | ||
labels = labels.to(DEVICE) | ||
loss += criterion(outputs, labels).item() | ||
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() | ||
accuracy = correct / len(testloader.dataset) | ||
return loss, accuracy | ||
|
||
|
||
def load_data(): | ||
"""Load CIFAR-10 (training and test set).""" | ||
trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) | ||
trainset = CIFAR10("./data", train=True, download=True, transform=trf) | ||
testset = CIFAR10("./data", train=False, download=True, transform=trf) | ||
return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset) | ||
|
||
|
||
# ############################################################################# | ||
# 2. Federation of the pipeline with Flower | ||
# ############################################################################# | ||
|
||
# Load model and data (simple CNN, CIFAR-10) | ||
net = Net().to(DEVICE) | ||
trainloader, testloader = load_data() | ||
|
||
|
||
# Define Flower client | ||
class FlowerClient(fl.client.NumPyClient): | ||
def get_parameters(self, config): | ||
return [val.cpu().numpy() for _, val in net.state_dict().items()] | ||
|
||
def set_parameters(self, parameters): | ||
params_dict = zip(net.state_dict().keys(), parameters) | ||
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) | ||
net.load_state_dict(state_dict, strict=True) | ||
|
||
def fit(self, parameters, config): | ||
self.set_parameters(parameters) | ||
train(net, trainloader, epochs=1) | ||
return self.get_parameters(config={}), len(trainloader.dataset), {} | ||
|
||
def evaluate(self, parameters, config): | ||
self.set_parameters(parameters) | ||
loss, accuracy = test(net, testloader) | ||
return loss, len(testloader.dataset), {"accuracy": accuracy} | ||
|
||
|
||
def client_fn(cid: str): | ||
""".""" | ||
return FlowerClient().to_client() | ||
|
||
|
||
# To run this: `flower-client --callable client:flower` | ||
flower = fl.flower.Flower( | ||
client_fn=client_fn, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
# Start Flower client | ||
fl.client.start_client( | ||
server_address="0.0.0.0:9092", | ||
client=FlowerClient().to_client(), | ||
transport="grpc-rere", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from typing import List, Tuple | ||
|
||
import flwr as fl | ||
from flwr.common import Metrics | ||
|
||
|
||
# Define metric aggregation function | ||
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: | ||
# Multiply accuracy of each client by number of examples used | ||
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] | ||
examples = [num_examples for num_examples, _ in metrics] | ||
|
||
# Aggregate and return custom metric (weighted average) | ||
return {"accuracy": sum(accuracies) / sum(examples)} | ||
|
||
|
||
# Define strategy | ||
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) | ||
|
||
# Start Flower driver | ||
fl.driver.start_driver( | ||
server_address="0.0.0.0:9091", | ||
config=fl.server.ServerConfig(num_rounds=3), | ||
strategy=strategy, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
[build-system] | ||
requires = ["poetry-core>=1.4.0"] | ||
build-backend = "poetry.core.masonry.api" | ||
|
||
[tool.poetry] | ||
name = "quickstart-pytorch" | ||
version = "0.1.0" | ||
description = "PyTorch Federated Learning Quickstart with Flower" | ||
authors = ["The Flower Authors <hello@flower.dev>"] | ||
|
||
[tool.poetry.dependencies] | ||
python = ">=3.8,<3.11" | ||
flwr = { path = "../../", develop = true, extras = ["simulation", "rest"] } | ||
torch = "1.13.1" | ||
torchvision = "0.14.1" | ||
tqdm = "4.65.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
flwr>=1.0, <2.0 | ||
torch==1.13.1 | ||
torchvision==0.14.1 | ||
tqdm==4.65.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#!/bin/bash | ||
set -e | ||
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ | ||
|
||
# Download the CIFAR-10 dataset | ||
python -c "from torchvision.datasets import CIFAR10; CIFAR10('./data', download=True)" | ||
|
||
echo "Starting server" | ||
python server.py & | ||
sleep 3 # Sleep for 3s to give the server enough time to start | ||
|
||
for i in `seq 0 1`; do | ||
echo "Starting client $i" | ||
python client.py & | ||
done | ||
|
||
# Enable CTRL+C to stop all background processes | ||
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM | ||
# Wait for all background processes to complete | ||
wait |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from typing import List, Tuple | ||
|
||
import flwr as fl | ||
from flwr.common import Metrics | ||
|
||
|
||
# Define metric aggregation function | ||
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: | ||
# Multiply accuracy of each client by number of examples used | ||
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] | ||
examples = [num_examples for num_examples, _ in metrics] | ||
|
||
# Aggregate and return custom metric (weighted average) | ||
return {"accuracy": sum(accuracies) / sum(examples)} | ||
|
||
|
||
# Define strategy | ||
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) | ||
|
||
# Start Flower server | ||
fl.server.start_server( | ||
server_address="0.0.0.0:8080", | ||
config=fl.server.ServerConfig(num_rounds=3), | ||
strategy=strategy, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.