From ee35b91d19dd3abf98f8cd08279ae33983156792 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Thu, 23 Nov 2023 15:35:51 +0100 Subject: [PATCH] Migrate PyTorch quickstart to Flower Datasets (#2314) Co-authored-by: jafermarq --- examples/quickstart-pytorch/README.md | 11 +++-- examples/quickstart-pytorch/client.py | 48 +++++++++++++++----- examples/quickstart-pytorch/pyproject.toml | 5 +- examples/quickstart-pytorch/requirements.txt | 5 +- examples/quickstart-pytorch/run.sh | 7 +-- 5 files changed, 51 insertions(+), 25 deletions(-) diff --git a/examples/quickstart-pytorch/README.md b/examples/quickstart-pytorch/README.md index f748894f4971..6de0dcf7ab32 100644 --- a/examples/quickstart-pytorch/README.md +++ b/examples/quickstart-pytorch/README.md @@ -1,7 +1,6 @@ # Flower Example using PyTorch -This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. -Running this example in itself is quite easy. +This introductory example to Flower uses PyTorch, but deep knowledge of PyTorch is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset. ## Project Setup @@ -56,18 +55,20 @@ Afterwards you are ready to start the Flower server as well as the clients. You python3 server.py ``` -Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminal windows and run the following commands. +Now you are ready to start the Flower clients which will participate in the learning. We need to specify the node id to +use different partitions of the data on different nodes. To do so simply open two more terminal windows and run the +following commands. Start client 1 in the first terminal: ```shell -python3 client.py +python3 client.py --node-id 0 ``` Start client 2 in the second terminal: ```shell -python3 client.py +python3 client.py --node-id 1 ``` You will see that PyTorch is starting a federated training. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) for a detailed explanation. diff --git a/examples/quickstart-pytorch/client.py b/examples/quickstart-pytorch/client.py index 6db7c8a855a0..8ce19a45403d 100644 --- a/examples/quickstart-pytorch/client.py +++ b/examples/quickstart-pytorch/client.py @@ -1,12 +1,13 @@ +import argparse import warnings from collections import OrderedDict import flwr as fl +from flwr_datasets import FederatedDataset import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader -from torchvision.datasets import CIFAR10 from torchvision.transforms import Compose, Normalize, ToTensor from tqdm import tqdm @@ -45,7 +46,9 @@ def train(net, trainloader, epochs): criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) for _ in range(epochs): - for images, labels in tqdm(trainloader): + for batch in tqdm(trainloader, "Training"): + images = batch["img"] + labels = batch["label"] optimizer.zero_grad() criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward() optimizer.step() @@ -56,30 +59,53 @@ def test(net, testloader): criterion = torch.nn.CrossEntropyLoss() correct, loss = 0, 0.0 with torch.no_grad(): - for images, labels in tqdm(testloader): + for batch in tqdm(testloader, "Testing"): + images = batch["img"] + labels = batch["label"] outputs = net(images.to(DEVICE)) - labels = labels.to(DEVICE) loss += criterion(outputs, labels).item() correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() accuracy = correct / len(testloader.dataset) return loss, accuracy -def load_data(): - """Load CIFAR-10 (training and test set).""" - trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) - trainset = CIFAR10("./data", train=True, download=True, transform=trf) - testset = CIFAR10("./data", train=False, download=True, transform=trf) - return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset) +def load_data(node_id): + """Load partition CIFAR10 data.""" + fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3}) + partition = fds.load_partition(node_id) + # Divide data on each node: 80% train, 20% test + partition_train_test = partition.train_test_split(test_size=0.2) + pytorch_transforms = Compose( + [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) + + def apply_transforms(batch): + """Apply transforms to the partition from FederatedDataset.""" + batch["img"] = [pytorch_transforms(img) for img in batch["img"]] + return batch + + partition_train_test = partition_train_test.with_transform(apply_transforms) + trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True) + testloader = DataLoader(partition_train_test["test"], batch_size=32) + return trainloader, testloader # ############################################################################# # 2. Federation of the pipeline with Flower # ############################################################################# +# Get node id +parser = argparse.ArgumentParser(description="Flower") +parser.add_argument( + "--node-id", + choices=[0, 1, 2], + type=int, + help="Partition of the dataset divided into 3 iid partitions created artificially.") +node_id = parser.parse_args().node_id + # Load model and data (simple CNN, CIFAR-10) net = Net().to(DEVICE) -trainloader, testloader = load_data() +trainloader, testloader = load_data(node_id=node_id) # Define Flower client diff --git a/examples/quickstart-pytorch/pyproject.toml b/examples/quickstart-pytorch/pyproject.toml index affdfee26d47..ec6a3af8c5b4 100644 --- a/examples/quickstart-pytorch/pyproject.toml +++ b/examples/quickstart-pytorch/pyproject.toml @@ -11,6 +11,7 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" -torch = "1.13.1" -torchvision = "0.14.1" +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } +torch = "2.1.1" +torchvision = "0.16.1" tqdm = "4.65.0" diff --git a/examples/quickstart-pytorch/requirements.txt b/examples/quickstart-pytorch/requirements.txt index 797ca6db6244..4e321e2cd0c2 100644 --- a/examples/quickstart-pytorch/requirements.txt +++ b/examples/quickstart-pytorch/requirements.txt @@ -1,4 +1,5 @@ flwr>=1.0, <2.0 -torch==1.13.1 -torchvision==0.14.1 +flwr-datasets[vision]>=0.0.2, <1.0.0 +torch==2.1.1 +torchvision==0.16.1 tqdm==4.65.0 diff --git a/examples/quickstart-pytorch/run.sh b/examples/quickstart-pytorch/run.sh index d2bf34f834b1..cdace99bb8df 100755 --- a/examples/quickstart-pytorch/run.sh +++ b/examples/quickstart-pytorch/run.sh @@ -2,16 +2,13 @@ set -e cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ -# Download the CIFAR-10 dataset -python -c "from torchvision.datasets import CIFAR10; CIFAR10('./data', download=True)" - echo "Starting server" python server.py & sleep 3 # Sleep for 3s to give the server enough time to start -for i in `seq 0 1`; do +for i in $(seq 0 1); do echo "Starting client $i" - python client.py & + python client.py --node-id "$i" & done # Enable CTRL+C to stop all background processes