diff --git a/examples/quickstart-pytorch/client.py b/examples/quickstart-pytorch/client.py index b5971496aace..b95ce1f00644 100644 --- a/examples/quickstart-pytorch/client.py +++ b/examples/quickstart-pytorch/client.py @@ -1,7 +1,7 @@ +import argparse import warnings from collections import OrderedDict -import click import flwr as fl from flwr_datasets import FederatedDataset import torch @@ -69,16 +69,10 @@ def test(net, testloader): return loss, accuracy -@click.command() -@click.option( - "--node-id", - type=click.Choice(["0", "1", "2"]), - help="Partition of the dataset, which is divided into 3 iid partitions created artificially", -) def load_data(node_id): """Load partition CIFAR10 data.""" fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3}) - partition = fds.load_partition(int(node_id), "train") + partition = fds.load_partition(node_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2) pytorch_transforms = Compose( @@ -100,9 +94,15 @@ def apply_transforms(batch): # 2. Federation of the pipeline with Flower # ############################################################################# +# Get node id +parser = argparse.ArgumentParser(description="Flower") +parser.add_argument("--node-id", choices=[0, 1, 2], type=int, help="Partition of the " +"dataset, which is divided into 3 iid partitions created artificially.") +node_id = parser.parse_args().node_id + # Load model and data (simple CNN, CIFAR-10) net = Net().to(DEVICE) -trainloader, testloader = load_data(standalone_mode=False) +trainloader, testloader = load_data(node_id=node_id) # Define Flower client diff --git a/examples/quickstart-pytorch/pyproject.toml b/examples/quickstart-pytorch/pyproject.toml index e7ca9738c9e4..582d45093da5 100644 --- a/examples/quickstart-pytorch/pyproject.toml +++ b/examples/quickstart-pytorch/pyproject.toml @@ -15,4 +15,3 @@ flwr-datasets = {extras = ["vision"], version = ">=0.0.2,<1.0.0" } torch = "2.1.1" torchvision = "0.16.1" tqdm = "4.65.0" -click = "8.1.7" diff --git a/examples/quickstart-pytorch/requirements.txt b/examples/quickstart-pytorch/requirements.txt index 5dc153d117e5..fc1f90777ad1 100644 --- a/examples/quickstart-pytorch/requirements.txt +++ b/examples/quickstart-pytorch/requirements.txt @@ -3,4 +3,3 @@ flwr-datasets>[vision]=0.0.2, <1.0.0 torch==2.1.1 torchvision==0.16.1 tqdm==4.65.0 -click==8.1.7