Skip to content

Commit

Permalink
Migrate from click to argparse
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Nov 23, 2023
1 parent acaf3aa commit bfa744c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
18 changes: 9 additions & 9 deletions examples/quickstart-pytorch/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import warnings
from collections import OrderedDict

import click
import flwr as fl
from flwr_datasets import FederatedDataset
import torch
Expand Down Expand Up @@ -69,16 +69,10 @@ def test(net, testloader):
return loss, accuracy


@click.command()
@click.option(
"--node-id",
type=click.Choice(["0", "1", "2"]),
help="Partition of the dataset, which is divided into 3 iid partitions created artificially",
)
def load_data(node_id):
"""Load partition CIFAR10 data."""
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3})
partition = fds.load_partition(int(node_id), "train")
partition = fds.load_partition(node_id)
# Divide data on each node: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2)
pytorch_transforms = Compose(
Expand All @@ -100,9 +94,15 @@ def apply_transforms(batch):
# 2. Federation of the pipeline with Flower
# #############################################################################

# Get node id
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument("--node-id", choices=[0, 1, 2], type=int, help="Partition of the "
"dataset, which is divided into 3 iid partitions created artificially.")
node_id = parser.parse_args().node_id

# Load model and data (simple CNN, CIFAR-10)
net = Net().to(DEVICE)
trainloader, testloader = load_data(standalone_mode=False)
trainloader, testloader = load_data(node_id=node_id)


# Define Flower client
Expand Down
1 change: 0 additions & 1 deletion examples/quickstart-pytorch/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,3 @@ flwr-datasets = {extras = ["vision"], version = ">=0.0.2,<1.0.0" }
torch = "2.1.1"
torchvision = "0.16.1"
tqdm = "4.65.0"
click = "8.1.7"
1 change: 0 additions & 1 deletion examples/quickstart-pytorch/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@ flwr-datasets>[vision]=0.0.2, <1.0.0
torch==2.1.1
torchvision==0.16.1
tqdm==4.65.0
click==8.1.7

0 comments on commit bfa744c

Please sign in to comment.