Skip to content

Commit

Permalink
refactor: use alexnet to replace efficientnet (#2782)
Browse files Browse the repository at this point in the history
Co-authored-by: helin1 <[email protected]>
Co-authored-by: jafermarq <[email protected]>
  • Loading branch information
3 people authored Jan 24, 2024
1 parent 472f42a commit 88f2e3e
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 53 deletions.
2 changes: 1 addition & 1 deletion examples/advanced-pytorch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ but this can be changed by removing the `--toy` argument in the script. You can

The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill).

You can also manually run `python3 server.py` and `python3 client.py --client-id <ID>` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network).
You can also manually run `python3 server.py` and `python3 client.py --client-id <ID>` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network). In addition, you can make your clients use either `EfficientNet` (default) or `AlexNet` (but all clients in the experiment should use the same model). Switch between models using the `--model` flag when launching `client.py` and `server.py`.
36 changes: 24 additions & 12 deletions examples/advanced-pytorch/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import utils
from torch.utils.data import DataLoader
import torchvision.datasets
import torch
import flwr as fl
import argparse
Expand All @@ -17,26 +16,31 @@ def __init__(
trainset: datasets.Dataset,
testset: datasets.Dataset,
device: torch.device,
model_str: str,
validation_split: int = 0.1,
):
self.device = device
self.trainset = trainset
self.testset = testset
self.validation_split = validation_split
if model_str == "alexnet":
self.model = utils.load_alexnet(classes=10)
else:
self.model = utils.load_efficientnet(classes=10)

def set_parameters(self, parameters):
"""Loads a efficientnet model and replaces it parameters with the ones given."""
model = utils.load_efficientnet(classes=10)
params_dict = zip(model.state_dict().keys(), parameters)
"""Loads an alexnet or efficientnet model and replaces its parameters with the
ones given."""

params_dict = zip(self.model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)
return model
self.model.load_state_dict(state_dict, strict=True)

def fit(self, parameters, config):
"""Train parameters on the locally held training set."""

# Update local model parameters
model = self.set_parameters(parameters)
self.set_parameters(parameters)

# Get hyperparameters for this round
batch_size: int = config["batch_size"]
Expand All @@ -49,25 +53,25 @@ def fit(self, parameters, config):
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(valset, batch_size=batch_size)

results = utils.train(model, train_loader, val_loader, epochs, self.device)
results = utils.train(self.model, train_loader, val_loader, epochs, self.device)

parameters_prime = utils.get_model_params(model)
parameters_prime = utils.get_model_params(self.model)
num_examples_train = len(trainset)

return parameters_prime, num_examples_train, results

def evaluate(self, parameters, config):
"""Evaluate parameters on the locally held test set."""
# Update local model parameters
model = self.set_parameters(parameters)
self.set_parameters(parameters)

# Get config values
steps: int = config["val_steps"]

# Evaluate global model parameters on the local test data and return results
testloader = DataLoader(self.testset, batch_size=16)

loss, accuracy = utils.test(model, testloader, steps, self.device)
loss, accuracy = utils.test(self.model, testloader, steps, self.device)
return float(loss), len(self.testset), {"accuracy": float(accuracy)}


Expand Down Expand Up @@ -121,6 +125,14 @@ def main() -> None:
required=False,
help="Set to true to use GPU. Default: False",
)
parser.add_argument(
"--model",
type=str,
default="efficientnet",
choices=["efficientnet", "alexnet"],
help="Use either Efficientnet or Alexnet models. \
If you want to achieve differential privacy, please use the Alexnet model",
)

args = parser.parse_args()

Expand All @@ -138,7 +150,7 @@ def main() -> None:
trainset = trainset.select(range(10))
testset = testset.select(range(10))
# Start Flower client
client = CifarClient(trainset, testset, device)
client = CifarClient(trainset, testset, device, args.model)

fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client)

Expand Down
5 changes: 0 additions & 5 deletions examples/advanced-pytorch/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

# Download the EfficientNetB0 model
python -c "import torch; torch.hub.load( \
'NVIDIA/DeepLearningExamples:torchhub', \
'nvidia_efficientnet_b0', pretrained=True)"

python server.py --toy &
sleep 10 # Sleep for 10s to give the server enough time to start and download the dataset

Expand Down
17 changes: 14 additions & 3 deletions examples/advanced-pytorch/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,28 @@ def main():
help="Set to true to use only 10 datasamples for validation. \
Useful for testing purposes. Default: False",
)
parser.add_argument(
"--model",
type=str,
default="efficientnet",
choices=["efficientnet", "alexnet"],
help="Use either Efficientnet or Alexnet models. \
If you want to achieve differential privacy, please use the Alexnet model",
)

args = parser.parse_args()

model = utils.load_efficientnet(classes=10)
if args.model == "alexnet":
model = utils.load_alexnet(classes=10)
else:
model = utils.load_efficientnet(classes=10)

model_parameters = [val.cpu().numpy() for _, val in model.state_dict().items()]

# Create strategy
strategy = fl.server.strategy.FedAvg(
fraction_fit=0.2,
fraction_evaluate=0.2,
fraction_fit=1.0,
fraction_evaluate=1.0,
min_fit_clients=2,
min_evaluate_clients=2,
min_available_clients=10,
Expand Down
47 changes: 15 additions & 32 deletions examples/advanced-pytorch/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
from torch.utils.data import DataLoader

from torchvision.models import efficientnet_b0, AlexNet
import warnings

from flwr_datasets import FederatedDataset


warnings.filterwarnings("ignore")


Expand Down Expand Up @@ -48,7 +48,7 @@ def train(
net.to(device) # move model to GPU if available
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(
net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4
net.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4
)
net.train()
for _ in range(epochs):
Expand Down Expand Up @@ -98,38 +98,21 @@ def test(
return loss, accuracy


def replace_classifying_layer(efficientnet_model, num_classes: int = 10):
"""Replaces the final layer of the classifier."""
num_features = efficientnet_model.classifier.fc.in_features
efficientnet_model.classifier.fc = torch.nn.Linear(num_features, num_classes)


def load_efficientnet(entrypoint: str = "nvidia_efficientnet_b0", classes: int = None):
"""Loads pretrained efficientnet model from torch hub. Replaces final classifying
layer if classes is specified.
Args:
entrypoint: EfficientNet model to download.
For supported entrypoints, please refer
https://pytorch.org/hub/nvidia_deeplearningexamples_efficientnet/
classes: Number of classes in final classifying layer. Leave as None to get
the downloaded
model untouched.
Returns:
EfficientNet Model
Note: One alternative implementation can be found at
https://github.com/lukemelas/EfficientNet-PyTorch
"""
efficientnet = torch.hub.load(
"NVIDIA/DeepLearningExamples:torchhub", entrypoint, pretrained=True
)

if classes is not None:
replace_classifying_layer(efficientnet, classes)
def load_efficientnet(classes: int = 10):
"""Loads EfficientNetB0 from TorchVision."""
efficientnet = efficientnet_b0(pretrained=True)
# Re-init output linear layer with the right number of classes
model_classes = efficientnet.classifier[1].in_features
if classes != model_classes:
efficientnet.classifier[1] = torch.nn.Linear(model_classes, classes)
return efficientnet


def get_model_params(model):
"""Returns a model's parameters."""
return [val.cpu().numpy() for _, val in model.state_dict().items()]


def load_alexnet(classes):
"""Load AlexNet model from TorchVision."""
return AlexNet(num_classes=classes)

0 comments on commit 88f2e3e

Please sign in to comment.