Skip to content

Commit

Permalink
Merge branch 'main' into xgb-flwr-sim-comprehensive
Browse files Browse the repository at this point in the history
  • Loading branch information
yan-gao-GY authored Jan 25, 2024
2 parents 0c505ef + b02f263 commit 9960b1b
Show file tree
Hide file tree
Showing 46 changed files with 1,007 additions and 242 deletions.
2 changes: 1 addition & 1 deletion examples/advanced-pytorch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ but this can be changed by removing the `--toy` argument in the script. You can

The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill).

You can also manually run `python3 server.py` and `python3 client.py --client-id <ID>` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network).
You can also manually run `python3 server.py` and `python3 client.py --client-id <ID>` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network). In addition, you can make your clients use either `EfficienNet` (default) or `AlexNet` (but all clients in the experiment should use the same). Switch between models using the `--model` flag when launching `client.py` and `server.py`.
41 changes: 26 additions & 15 deletions examples/advanced-pytorch/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import utils
from torch.utils.data import DataLoader
import torchvision.datasets
import torch
import flwr as fl
import argparse
Expand All @@ -17,26 +16,31 @@ def __init__(
trainset: datasets.Dataset,
testset: datasets.Dataset,
device: torch.device,
model_str: str,
validation_split: int = 0.1,
):
self.device = device
self.trainset = trainset
self.testset = testset
self.validation_split = validation_split
if model_str == "alexnet":
self.model = utils.load_alexnet(classes=10)
else:
self.model = utils.load_efficientnet(classes=10)

def set_parameters(self, parameters):
"""Loads a efficientnet model and replaces it parameters with the ones given."""
model = utils.load_efficientnet(classes=10)
params_dict = zip(model.state_dict().keys(), parameters)
"""Loads a alexnet or efficientnet model and replaces it parameters with the
ones given."""

params_dict = zip(self.model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)
return model
self.model.load_state_dict(state_dict, strict=True)

def fit(self, parameters, config):
"""Train parameters on the locally held training set."""

# Update local model parameters
model = self.set_parameters(parameters)
self.set_parameters(parameters)

# Get hyperparameters for this round
batch_size: int = config["batch_size"]
Expand All @@ -49,25 +53,25 @@ def fit(self, parameters, config):
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(valset, batch_size=batch_size)

results = utils.train(model, train_loader, val_loader, epochs, self.device)
results = utils.train(self.model, train_loader, val_loader, epochs, self.device)

parameters_prime = utils.get_model_params(model)
parameters_prime = utils.get_model_params(self.model)
num_examples_train = len(trainset)

return parameters_prime, num_examples_train, results

def evaluate(self, parameters, config):
"""Evaluate parameters on the locally held test set."""
# Update local model parameters
model = self.set_parameters(parameters)
self.set_parameters(parameters)

# Get config values
steps: int = config["val_steps"]

# Evaluate global model parameters on the local test data and return results
testloader = DataLoader(self.testset, batch_size=16)

loss, accuracy = utils.test(model, testloader, steps, self.device)
loss, accuracy = utils.test(self.model, testloader, steps, self.device)
return float(loss), len(self.testset), {"accuracy": float(accuracy)}


Expand Down Expand Up @@ -110,7 +114,7 @@ def main() -> None:
)
parser.add_argument(
"--toy",
action='store_true',
action="store_true",
help="Set to true to quicky run the client using only 10 datasamples. \
Useful for testing purposes. Default: False",
)
Expand All @@ -121,6 +125,14 @@ def main() -> None:
required=False,
help="Set to true to use GPU. Default: False",
)
parser.add_argument(
"--model",
type=str,
default="efficientnet",
choices=["efficientnet", "alexnet"],
help="Use either Efficientnet or Alexnet models. \
If you want to achieve differential privacy, please use the Alexnet model",
)

args = parser.parse_args()

Expand All @@ -138,9 +150,8 @@ def main() -> None:
trainset = trainset.select(range(10))
testset = testset.select(range(10))
# Start Flower client
client = CifarClient(trainset, testset, device)

fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client)
client = CifarClient(trainset, testset, device, args.model).to_client()
fl.client.start_client(server_address="127.0.0.1:8080", client=client)


if __name__ == "__main__":
Expand Down
5 changes: 0 additions & 5 deletions examples/advanced-pytorch/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

# Download the EfficientNetB0 model
python -c "import torch; torch.hub.load( \
'NVIDIA/DeepLearningExamples:torchhub', \
'nvidia_efficientnet_b0', pretrained=True)"

python server.py --toy &
sleep 10 # Sleep for 10s to give the server enough time to start and dowload the dataset

Expand Down
19 changes: 15 additions & 4 deletions examples/advanced-pytorch/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,32 @@ def main():
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--toy",
action='store_true',
action="store_true",
help="Set to true to use only 10 datasamples for validation. \
Useful for testing purposes. Default: False",
)
parser.add_argument(
"--model",
type=str,
default="efficientnet",
choices=["efficientnet", "alexnet"],
help="Use either Efficientnet or Alexnet models. \
If you want to achieve differential privacy, please use the Alexnet model",
)

args = parser.parse_args()

model = utils.load_efficientnet(classes=10)
if args.model == "alexnet":
model = utils.load_alexnet(classes=10)
else:
model = utils.load_efficientnet(classes=10)

model_parameters = [val.cpu().numpy() for _, val in model.state_dict().items()]

# Create strategy
strategy = fl.server.strategy.FedAvg(
fraction_fit=0.2,
fraction_evaluate=0.2,
fraction_fit=1.0,
fraction_evaluate=1.0,
min_fit_clients=2,
min_evaluate_clients=2,
min_available_clients=10,
Expand Down
71 changes: 29 additions & 42 deletions examples/advanced-pytorch/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
from torch.utils.data import DataLoader

from torchvision.models import efficientnet_b0, AlexNet
import warnings

from flwr_datasets import FederatedDataset


warnings.filterwarnings("ignore")


Expand All @@ -28,24 +28,27 @@ def load_centralized_data():

def apply_transforms(batch):
"""Apply transforms to the partition from FederatedDataset."""
pytorch_transforms = Compose([
Resize(256),
CenterCrop(224),
ToTensor(),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
pytorch_transforms = Compose(
[
Resize(256),
CenterCrop(224),
ToTensor(),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
return batch


def train(net, trainloader, valloader, epochs,
device: torch.device = torch.device("cpu")):
def train(
net, trainloader, valloader, epochs, device: torch.device = torch.device("cpu")
):
"""Train the network on the training set."""
print("Starting training...")
net.to(device) # move model to GPU if available
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(
net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4
net.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4
)
net.train()
for _ in range(epochs):
Expand All @@ -71,8 +74,9 @@ def train(net, trainloader, valloader, epochs,
return results


def test(net, testloader, steps: int = None,
device: torch.device = torch.device("cpu")):
def test(
net, testloader, steps: int = None, device: torch.device = torch.device("cpu")
):
"""Validate the network on the entire test set."""
print("Starting evalutation...")
net.to(device) # move model to GPU if available
Expand All @@ -94,38 +98,21 @@ def test(net, testloader, steps: int = None,
return loss, accuracy


def replace_classifying_layer(efficientnet_model, num_classes: int = 10):
"""Replaces the final layer of the classifier."""
num_features = efficientnet_model.classifier.fc.in_features
efficientnet_model.classifier.fc = torch.nn.Linear(num_features, num_classes)


def load_efficientnet(entrypoint: str = "nvidia_efficientnet_b0", classes: int = None):
"""Loads pretrained efficientnet model from torch hub. Replaces final classifying
layer if classes is specified.
Args:
entrypoint: EfficientNet model to download.
For supported entrypoints, please refer
https://pytorch.org/hub/nvidia_deeplearningexamples_efficientnet/
classes: Number of classes in final classifying layer. Leave as None to get
the downloaded
model untouched.
Returns:
EfficientNet Model
Note: One alternative implementation can be found at
https://github.com/lukemelas/EfficientNet-PyTorch
"""
efficientnet = torch.hub.load(
"NVIDIA/DeepLearningExamples:torchhub", entrypoint, pretrained=True
)

if classes is not None:
replace_classifying_layer(efficientnet, classes)
def load_efficientnet(classes: int = 10):
"""Loads EfficienNetB0 from TorchVision."""
efficientnet = efficientnet_b0(pretrained=True)
# Re-init output linear layer with the right number of classes
model_classes = efficientnet.classifier[1].in_features
if classes != model_classes:
efficientnet.classifier[1] = torch.nn.Linear(model_classes, classes)
return efficientnet


def get_model_params(model):
"""Returns a model's parameters."""
return [val.cpu().numpy() for _, val in model.state_dict().items()]


def load_alexnet(classes):
"""Load AlexNet model from TorchVision."""
return AlexNet(num_classes=classes)
6 changes: 3 additions & 3 deletions examples/advanced-tensorflow/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def main() -> None:
)
parser.add_argument(
"--toy",
action='store_true',
action="store_true",
help="Set to true to quicky run the client using only 10 datasamples. "
"Useful for testing purposes. Default: False",
)
Expand All @@ -106,9 +106,9 @@ def main() -> None:
x_test, y_test = x_test[:10], y_test[:10]

# Start Flower client
client = CifarClient(model, x_train, y_train, x_test, y_test)
client = CifarClient(model, x_train, y_train, x_test, y_test).to_client()

fl.client.start_numpy_client(
fl.client.start_client(
server_address="127.0.0.1:8080",
client=client,
root_certificates=Path(".cache/certificates/ca.crt").read_bytes(),
Expand Down
2 changes: 1 addition & 1 deletion examples/custom-metrics/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ def evaluate(self, parameters, config):


# Start Flower client
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient())
fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
4 changes: 2 additions & 2 deletions examples/embedded-devices/client_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,11 @@ def main():
trainsets, valsets, _ = prepare_dataset(use_mnist)

# Start Flower client setting its associated data partition
fl.client.start_numpy_client(
fl.client.start_client(
server_address=args.server_address,
client=FlowerClient(
trainset=trainsets[args.cid], valset=valsets[args.cid], use_mnist=use_mnist
),
).to_client(),
)


Expand Down
4 changes: 2 additions & 2 deletions examples/embedded-devices/client_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ def main():
trainset, valset = partitions[args.cid]

# Start Flower client setting its associated data partition
fl.client.start_numpy_client(
fl.client.start_client(
server_address=args.server_address,
client=FlowerClient(trainset=trainset, valset=valset, use_mnist=use_mnist),
client=FlowerClient(trainset=trainset, valset=valset, use_mnist=use_mnist).to_client(),
)


Expand Down
2 changes: 1 addition & 1 deletion examples/flower-in-30-minutes/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@
"\n",
" return FlowerClient(\n",
" trainloader=trainloaders[int(cid)], vallodaer=valloaders[int(cid)]\n",
" )\n",
" ).to_client()\n",
"\n",
" return client_fn\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/mt-pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def evaluate(self, parameters, config):


# Start Flower client
fl.client.start_numpy_client(
fl.client.start_client(
server_address="0.0.0.0:9092", # "0.0.0.0:9093" for REST
client=FlowerClient(),
client=FlowerClient().to_client(),
transport="grpc-rere", # "rest" for REST
)
4 changes: 2 additions & 2 deletions examples/opacus/dp_cifar_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def load_data():

model = Net()
trainloader, testloader, sample_rate = load_data()
fl.client.start_numpy_client(
fl.client.start_client(
server_address="127.0.0.1:8080",
client=DPCifarClient(model, trainloader, testloader),
client=DPCifarClient(model, trainloader, testloader).to_client(),
)
Loading

0 comments on commit 9960b1b

Please sign in to comment.