Skip to content

Commit

Permalink
Make examples use start_client(). (#2718)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Jan 25, 2024
1 parent 3b5df2a commit b02f263
Show file tree
Hide file tree
Showing 28 changed files with 48 additions and 49 deletions.
5 changes: 2 additions & 3 deletions examples/advanced-pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,8 @@ def main() -> None:
trainset = trainset.select(range(10))
testset = testset.select(range(10))
# Start Flower client
client = CifarClient(trainset, testset, device, args.model)

fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client)
client = CifarClient(trainset, testset, device, args.model).to_client()
fl.client.start_client(server_address="127.0.0.1:8080", client=client)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions examples/advanced-tensorflow/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def main() -> None:
x_test, y_test = x_test[:10], y_test[:10]

# Start Flower client
client = CifarClient(model, x_train, y_train, x_test, y_test)
client = CifarClient(model, x_train, y_train, x_test, y_test).to_client()

fl.client.start_numpy_client(
fl.client.start_client(
server_address="127.0.0.1:8080",
client=client,
root_certificates=Path(".cache/certificates/ca.crt").read_bytes(),
Expand Down
2 changes: 1 addition & 1 deletion examples/custom-metrics/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ def evaluate(self, parameters, config):


# Start Flower client
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient())
fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
4 changes: 2 additions & 2 deletions examples/embedded-devices/client_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,11 @@ def main():
trainsets, valsets, _ = prepare_dataset(use_mnist)

# Start Flower client setting its associated data partition
fl.client.start_numpy_client(
fl.client.start_client(
server_address=args.server_address,
client=FlowerClient(
trainset=trainsets[args.cid], valset=valsets[args.cid], use_mnist=use_mnist
),
).to_client(),
)


Expand Down
4 changes: 2 additions & 2 deletions examples/embedded-devices/client_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ def main():
trainset, valset = partitions[args.cid]

# Start Flower client setting its associated data partition
fl.client.start_numpy_client(
fl.client.start_client(
server_address=args.server_address,
client=FlowerClient(trainset=trainset, valset=valset, use_mnist=use_mnist),
client=FlowerClient(trainset=trainset, valset=valset, use_mnist=use_mnist).to_client(),
)


Expand Down
2 changes: 1 addition & 1 deletion examples/flower-in-30-minutes/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@
"\n",
" return FlowerClient(\n",
" trainloader=trainloaders[int(cid)], vallodaer=valloaders[int(cid)]\n",
" )\n",
" ).to_client()\n",
"\n",
" return client_fn\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/mt-pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def evaluate(self, parameters, config):


# Start Flower client
fl.client.start_numpy_client(
fl.client.start_client(
server_address="0.0.0.0:9092", # "0.0.0.0:9093" for REST
client=FlowerClient(),
client=FlowerClient().to_client(),
transport="grpc-rere", # "rest" for REST
)
4 changes: 2 additions & 2 deletions examples/opacus/dp_cifar_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def load_data():

model = Net()
trainloader, testloader, sample_rate = load_data()
fl.client.start_numpy_client(
fl.client.start_client(
server_address="127.0.0.1:8080",
client=DPCifarClient(model, trainloader, testloader),
client=DPCifarClient(model, trainloader, testloader).to_client(),
)
16 changes: 8 additions & 8 deletions examples/opacus/dp_cifar_simulation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import math
from collections import OrderedDict
from typing import Callable, Optional, Tuple
from typing import Callable, Dict, Optional, Tuple

import flwr as fl
import numpy as np
import torch
import torchvision.transforms as transforms
from opacus.dp_model_inspector import DPModelInspector
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from flwr.common.typing import Scalar

from dp_cifar_main import DEVICE, PARAMS, DPCifarClient, Net, test

Expand All @@ -23,8 +23,6 @@ def client_fn(cid: str) -> fl.client.Client:
# Load model.
model = Net()
# Check model is compatible with Opacus.
# inspector = DPModelInspector()
# print(f"Is the model valid? {inspector.validate(model)}")

# Load data partition (divide CIFAR10 into NUM_CLIENTS distinct partitions, using 30% for validation).
transform = transforms.Compose(
Expand All @@ -45,12 +43,14 @@ def client_fn(cid: str) -> fl.client.Client:
client_trainloader = DataLoader(client_trainset, PARAMS["batch_size"])
client_testloader = DataLoader(client_testset, PARAMS["batch_size"])

return DPCifarClient(model, client_trainloader, client_testloader)
return DPCifarClient(model, client_trainloader, client_testloader).to_client()


# Define an evaluation function for centralized evaluation (using whole CIFAR10 testset).
def get_evaluate_fn() -> Callable[[fl.common.NDArrays], Optional[Tuple[float, float]]]:
def evaluate(weights: fl.common.NDArrays) -> Optional[Tuple[float, float]]:
def evaluate(
server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar]
):
transform = transforms.Compose(
[
transforms.ToTensor(),
Expand All @@ -63,7 +63,7 @@ def evaluate(weights: fl.common.NDArrays) -> Optional[Tuple[float, float]]:
state_dict = OrderedDict(
{
k: torch.tensor(np.atleast_1d(v))
for k, v in zip(model.state_dict().keys(), weights)
for k, v in zip(model.state_dict().keys(), parameters)
}
)
model.load_state_dict(state_dict, strict=True)
Expand All @@ -82,7 +82,7 @@ def main() -> None:
client_fn=client_fn,
num_clients=NUM_CLIENTS,
client_resources={"num_cpus": 1},
num_rounds=3,
config=fl.server.ServerConfig(num_rounds=3),
strategy=fl.server.strategy.FedAvg(
fraction_fit=0.1, fraction_evaluate=0.1, evaluate_fn=get_evaluate_fn()
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def evaluate(self, parameters, config):
loss = test(net, testloader)
return float(loss), len(testloader), {}

fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=CifarClient())
fl.client.start_client(server_address="127.0.0.1:8080", client=CifarClient().to_client())


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions examples/pytorch-from-centralized-to-federated/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def main() -> None:
_ = model(next(iter(trainloader))["img"].to(DEVICE))

# Start client
client = CifarClient(model, trainloader, testloader)
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client)
client = CifarClient(model, trainloader, testloader).to_client()
fl.client.start_client(server_address="127.0.0.1:8080", client=client)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions examples/quickstart-fastai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def evaluate(self, parameters, config):


# Start Flower client
fl.client.start_numpy_client(
fl.client.start_client(
server_address="127.0.0.1:8080",
client=FlowerClient(),
client=FlowerClient().to_client(),
)
2 changes: 1 addition & 1 deletion examples/quickstart-huggingface/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def evaluate(self, parameters, config):
return float(loss), len(testloader), {"accuracy": float(accuracy)}

# Start client
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=IMDBClient())
fl.client.start_client(server_address="127.0.0.1:8080", client=IMDBClient().to_client())


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion examples/quickstart-jax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ def evaluate(


# Start Flower client
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient())
fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
4 changes: 2 additions & 2 deletions examples/quickstart-mlcube/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def main():
os.path.dirname(os.path.abspath(__file__)), "workspaces", workspace_name
)

fl.client.start_numpy_client(
server_address="0.0.0.0:8080", client=MLCubeClient(workspace=workspace)
fl.client.start_client(
server_address="0.0.0.0:8080", client=MLCubeClient(workspace=workspace).to_client()
)


Expand Down
4 changes: 2 additions & 2 deletions examples/quickstart-pandas/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def fit(
X = dataset[column_names]

# Start Flower client
fl.client.start_numpy_client(
fl.client.start_client(
server_address="127.0.0.1:8080",
client=FlowerClient(X),
client=FlowerClient(X).to_client(),
)
4 changes: 2 additions & 2 deletions examples/quickstart-pytorch-lightning/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def main() -> None:
train_loader, val_loader, test_loader = mnist.load_data(node_id)

# Flower client
client = FlowerClient(model, train_loader, val_loader, test_loader)
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client)
client = FlowerClient(model, train_loader, val_loader, test_loader).to_client()
fl.client.start_client(server_address="127.0.0.1:8080", client=client)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions examples/quickstart-pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def evaluate(self, parameters, config):


# Start Flower client
fl.client.start_numpy_client(
fl.client.start_client(
server_address="127.0.0.1:8080",
client=FlowerClient(),
client=FlowerClient().to_client(),
)
2 changes: 1 addition & 1 deletion examples/quickstart-tabnet/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,4 @@ def evaluate(self, parameters, config):


# Start Flower client
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=TabNetClient())
fl.client.start_client(server_address="127.0.0.1:8080", client=TabNetClient().to_client())
2 changes: 1 addition & 1 deletion examples/quickstart-tensorflow/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ def evaluate(self, parameters, config):


# Start Flower client
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=CifarClient())
fl.client.start_client(server_address="127.0.0.1:8080", client=CifarClient().to_client())
2 changes: 1 addition & 1 deletion examples/simulation-pytorch/sim.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@
" valloader = DataLoader(valset.with_transform(apply_transforms), batch_size=32)\n",
"\n",
" # Create and return client\n",
" return FlowerClient(trainloader, valloader)\n",
" return FlowerClient(trainloader, valloader).to_client()\n",
"\n",
" return client_fn\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/simulation-pytorch/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def client_fn(cid: str) -> fl.client.Client:
valset = valset.with_transform(apply_transforms)

# Create and return client
return FlowerClient(trainset, valset)
return FlowerClient(trainset, valset).to_client()

return client_fn

Expand Down
2 changes: 1 addition & 1 deletion examples/simulation-tensorflow/sim.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@
" )\n",
"\n",
" # Create and return client\n",
" return FlowerClient(trainset, valset)\n",
" return FlowerClient(trainset, valset).to_client()\n",
"\n",
" return client_fn\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/simulation-tensorflow/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def client_fn(cid: str) -> fl.client.Client:
)

# Create and return client
return FlowerClient(trainset, valset)
return FlowerClient(trainset, valset).to_client()

return client_fn

Expand Down
2 changes: 1 addition & 1 deletion examples/sklearn-logreg-mnist/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ def evaluate(self, parameters, config): # type: ignore
return loss, len(X_test), {"accuracy": accuracy}

# Start Flower client
fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=MnistClient())
fl.client.start_client(server_address="0.0.0.0:8080", client=MnistClient().to_client())
4 changes: 2 additions & 2 deletions examples/whisper-federated-finetuning/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def client_fn(cid: str):

return WhisperFlowerClient(
full_train_dataset, num_classes, disable_tqdm, compile
)
).to_client()

return client_fn

Expand Down Expand Up @@ -174,7 +174,7 @@ def run_client():
client_data_path=CLIENT_DATA,
)

fl.client.start_numpy_client(
fl.client.start_client(
server_address=f"{args.server_address}:8080", client=client_fn(args.cid)
)

Expand Down
2 changes: 1 addition & 1 deletion examples/xgboost-comprehensive/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,4 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:


# Start Flower client
fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient())
fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient().to_client())
2 changes: 1 addition & 1 deletion examples/xgboost-quickstart/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,4 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:


# Start Flower client
fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient())
fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient().to_client())

0 comments on commit b02f263

Please sign in to comment.