Skip to content

Commit

Permalink
Merge branch 'main' into add-vfl-example
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll authored Nov 23, 2023
2 parents 2b25cef + ee35b91 commit 6d04a1d
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 25 deletions.
11 changes: 6 additions & 5 deletions examples/quickstart-pytorch/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Flower Example using PyTorch

This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case.
Running this example in itself is quite easy.
This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset.

## Project Setup

Expand Down Expand Up @@ -56,18 +55,20 @@ Afterwards you are ready to start the Flower server as well as the clients. You
python3 server.py
```

Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminal windows and run the following commands.
Now you are ready to start the Flower clients which will participate in the learning. We need to specify the node id to
use different partitions of the data on different nodes. To do so simply open two more terminal windows and run the
following commands.

Start client 1 in the first terminal:

```shell
python3 client.py
python3 client.py --node-id 0
```

Start client 2 in the second terminal:

```shell
python3 client.py
python3 client.py --node-id 1
```

You will see that PyTorch is starting a federated training. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) for a detailed explanation.
48 changes: 37 additions & 11 deletions examples/quickstart-pytorch/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import argparse
import warnings
from collections import OrderedDict

import flwr as fl
from flwr_datasets import FederatedDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm

Expand Down Expand Up @@ -45,7 +46,9 @@ def train(net, trainloader, epochs):
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for _ in range(epochs):
for images, labels in tqdm(trainloader):
for batch in tqdm(trainloader, "Training"):
images = batch["img"]
labels = batch["label"]
optimizer.zero_grad()
criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
optimizer.step()
Expand All @@ -56,30 +59,53 @@ def test(net, testloader):
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
with torch.no_grad():
for images, labels in tqdm(testloader):
for batch in tqdm(testloader, "Testing"):
images = batch["img"]
labels = batch["label"]
outputs = net(images.to(DEVICE))
labels = labels.to(DEVICE)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
accuracy = correct / len(testloader.dataset)
return loss, accuracy


def load_data():
"""Load CIFAR-10 (training and test set)."""
trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = CIFAR10("./data", train=True, download=True, transform=trf)
testset = CIFAR10("./data", train=False, download=True, transform=trf)
return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset)
def load_data(node_id):
"""Load partition CIFAR10 data."""
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3})
partition = fds.load_partition(node_id)
# Divide data on each node: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2)
pytorch_transforms = Compose(
[ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

def apply_transforms(batch):
"""Apply transforms to the partition from FederatedDataset."""
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
return batch

partition_train_test = partition_train_test.with_transform(apply_transforms)
trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
testloader = DataLoader(partition_train_test["test"], batch_size=32)
return trainloader, testloader


# #############################################################################
# 2. Federation of the pipeline with Flower
# #############################################################################

# Get node id
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--node-id",
choices=[0, 1, 2],
type=int,
help="Partition of the dataset divided into 3 iid partitions created artificially.")
node_id = parser.parse_args().node_id

# Load model and data (simple CNN, CIFAR-10)
net = Net().to(DEVICE)
trainloader, testloader = load_data()
trainloader, testloader = load_data(node_id=node_id)


# Define Flower client
Expand Down
5 changes: 3 additions & 2 deletions examples/quickstart-pytorch/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ authors = ["The Flower Authors <[email protected]>"]
[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = ">=1.0,<2.0"
torch = "1.13.1"
torchvision = "0.14.1"
flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" }
torch = "2.1.1"
torchvision = "0.16.1"
tqdm = "4.65.0"
5 changes: 3 additions & 2 deletions examples/quickstart-pytorch/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
flwr>=1.0, <2.0
torch==1.13.1
torchvision==0.14.1
flwr-datasets[vision]>=0.0.2, <1.0.0
torch==2.1.1
torchvision==0.16.1
tqdm==4.65.0
7 changes: 2 additions & 5 deletions examples/quickstart-pytorch/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

# Download the CIFAR-10 dataset
python -c "from torchvision.datasets import CIFAR10; CIFAR10('./data', download=True)"

echo "Starting server"
python server.py &
sleep 3 # Sleep for 3s to give the server enough time to start

for i in `seq 0 1`; do
for i in $(seq 0 1); do
echo "Starting client $i"
python client.py &
python client.py --node-id "$i" &
done

# Enable CTRL+C to stop all background processes
Expand Down

0 comments on commit 6d04a1d

Please sign in to comment.