Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Sep 20, 2023
1 parent 0c0f9cf commit 3bd0a04
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 33 deletions.
6 changes: 6 additions & 0 deletions examples/simulation-pytorch/sim.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@
"source": [
"from datasets import Dataset\n",
"from flwr_datasets import FederatedDataset\n",
"from datasets.utils.logging import disable_progress_bar\n",
"\n",
"# Let's set a simulation involving a total of 100 clients\n",
"NUM_CLIENTS = 100\n",
Expand Down Expand Up @@ -568,13 +569,18 @@
"# client needs exclusive access to these many resources in order to run\n",
"client_resources = {\"num_cpus\": 1, \"num_gpus\": 0.0}\n",
"\n",
"# Let's disable tqdm progress bar in the main thread (used by the server)\n",
"disable_progress_bar()\n",
"\n",
"history = fl.simulation.start_simulation(\n",
" client_fn=client_fn_callback, # a callback to construct a client\n",
" num_clients=NUM_CLIENTS, # total number of clients in the experiment\n",
" config=fl.server.ServerConfig(num_rounds=10), # let's run for 10 rounds\n",
" strategy=strategy, # the strategy that will orchestrate the whole FL pipeline\n",
" client_resources=client_resources,\n",
" actor_kwargs={\n",
" \"on_actor_init_fn\": disable_progress_bar # disable tqdm on each actor/process spawning virtual clients\n",
" },\n",
")"
]
},
Expand Down
17 changes: 12 additions & 5 deletions examples/simulation-pytorch/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from flwr.common.typing import Scalar

from datasets import Dataset
from datasets.utils.logging import disable_progress_bar
from flwr_datasets import FederatedDataset

from utils import Net, train, test, get_mnist_transforms

from utils import Net, train, test, mnist_transforms

parser = argparse.ArgumentParser(description="Flower Simulation with PyTorch")

Expand Down Expand Up @@ -101,11 +101,12 @@ def client_fn(cid: str) -> fl.client.Client:

# Now we apply the transform to each batch.
trainset = trainset.map(
lambda img: {"img": get_mnist_transforms()(img)}, input_columns="image"
lambda img: {"image": mnist_transforms(img)},
input_columns="image",
).with_format("torch")

valset = valset.map(
lambda img: {"img": get_mnist_transforms()(img)}, input_columns="image"
lambda img: {"image": mnist_transforms(img)}, input_columns="image"
).with_format("torch")

# Create and return client
Expand Down Expand Up @@ -160,9 +161,12 @@ def evaluate(

# Apply transform to dataset
testset = centralized_testset.map(
lambda img: {"img": get_mnist_transforms()(img)}, input_columns="image"
lambda img: {"image": mnist_transforms(img)}, input_columns="image"
).with_format("torch")

# Disable tqdm for dataset preprocessing
disable_progress_bar()

testloader = DataLoader(testset, batch_size=50)
loss, accuracy = test(model, testloader, device=device)

Expand Down Expand Up @@ -206,6 +210,9 @@ def main():
client_resources=client_resources,
config=fl.server.ServerConfig(num_rounds=args.num_rounds),
strategy=strategy,
actor_kwargs={
"on_actor_init_fn": disable_progress_bar # disable tqdm on each actor/process spawning virtual clients
},
)


Expand Down
17 changes: 6 additions & 11 deletions examples/simulation-pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from torchvision.transforms import ToTensor, Normalize, Compose


# transformation to convert images to tensors and apply normalization
mnist_transforms = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])


# Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')
class Net(nn.Module):
def __init__(self, num_classes: int = 10) -> None:
Expand Down Expand Up @@ -33,7 +37,7 @@ def train(net, trainloader, optim, epochs, device: str):
net.train()
for _ in range(epochs):
for batch in trainloader:
images, labels = batch["img"].to(device), batch["label"].to(device)
images, labels = batch["image"].to(device), batch["label"].to(device)
optim.zero_grad()
loss = criterion(net(images), labels)
loss.backward()
Expand All @@ -48,19 +52,10 @@ def test(net, testloader, device: str):
net.eval()
with torch.no_grad():
for data in testloader:
images, labels = data["img"].to(device), data["label"].to(device)
images, labels = data["image"].to(device), data["label"].to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1)
correct += (predicted == labels).sum().item()
accuracy = correct / len(testloader.dataset)
return loss, accuracy


def get_mnist_transforms():
"""Get transformation for MNIST dataset."""

# transformation to convert images to tensors and apply normalization
tr = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

return tr
19 changes: 4 additions & 15 deletions examples/simulation-tensorflow/sim.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
"outputs": [],
"source": [
"class FlowerClient(fl.client.NumPyClient):\n",
" def __init__(self, trainset, valset) -> None:\n",
" def __init__(self, trainset, valset) -> None:\n",
" # Create model\n",
" self.model = get_model()\n",
" self.x_train, self.y_train = trainset[\"image\"], trainset[\"label\"]\n",
Expand Down Expand Up @@ -189,7 +189,6 @@
"\n",
" trainset = client_dataset_splits[\"train\"].with_format(\"tf\")\n",
" valset = client_dataset_splits[\"test\"].with_format(\"tf\")\n",
" \n",
"\n",
" # Create and return client\n",
" return FlowerClient(trainset, valset)\n",
Expand Down Expand Up @@ -219,7 +218,9 @@
" ):\n",
" model = get_model() # Construct the model\n",
" model.set_weights(parameters) # Update model with the latest parameters\n",
" loss, accuracy = model.evaluate(testset[\"image\"], testset[\"label\"], verbose=VERBOSE)\n",
" loss, accuracy = model.evaluate(\n",
" testset[\"image\"], testset[\"label\"], verbose=VERBOSE\n",
" )\n",
" return loss, {\"accuracy\": accuracy}\n",
"\n",
" return evaluate"
Expand Down Expand Up @@ -330,18 +331,6 @@
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
Expand Down
5 changes: 3 additions & 2 deletions examples/simulation-tensorflow/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def client_fn(cid: str) -> fl.client.Client:

trainset = client_dataset_splits["train"].with_format("tf")
valset = client_dataset_splits["test"].with_format("tf")


# Create and return client
return FlowerClient(trainset, valset)
Expand Down Expand Up @@ -137,7 +136,9 @@ def evaluate(
):
model = get_model() # Construct the model
model.set_weights(parameters) # Update model with the latest parameters
loss, accuracy = model.evaluate(testset["image"], testset["label"], verbose=VERBOSE)
loss, accuracy = model.evaluate(
testset["image"], testset["label"], verbose=VERBOSE
)
return loss, {"accuracy": accuracy}

return evaluate
Expand Down

0 comments on commit 3bd0a04

Please sign in to comment.