Skip to content

Commit

Permalink
Make flower-client HTTPS by default (#2636)
Browse files Browse the repository at this point in the history
Co-authored-by: Taner Topal <[email protected]>
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
3 people authored Nov 28, 2023
1 parent 63462fa commit 875b6c7
Show file tree
Hide file tree
Showing 21 changed files with 169 additions and 33 deletions.
6 changes: 5 additions & 1 deletion e2e/bare-https/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,17 @@ def evaluate(self, parameters, config):
return loss, 1, {"accuracy": accuracy}

def client_fn(cid):
return FlowerClient()
return FlowerClient().to_client()

flower = fl.flower.Flower(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
fl.client.start_numpy_client(
server_address="127.0.0.1:8080",
client=FlowerClient(),
root_certificates=Path("certificates/ca.crt").read_bytes(),
insecure=False,
)
5 changes: 4 additions & 1 deletion e2e/bare/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ def evaluate(self, parameters, config):
return loss, 1, {"accuracy": accuracy}

def client_fn(cid):
return FlowerClient()
return FlowerClient().to_client()

flower = fl.flower.Flower(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
Expand Down
7 changes: 6 additions & 1 deletion e2e/fastai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ def evaluate(self, parameters, config):


def client_fn(cid):
return FlowerClient()
return FlowerClient().to_client()


flower = fl.flower.Flower(
client_fn=client_fn,
)


if __name__ == "__main__":
Expand Down
6 changes: 5 additions & 1 deletion e2e/jax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ def evaluate(
return float(loss), num_examples, {"loss": float(loss)}

def client_fn(cid):
return FlowerClient()
return FlowerClient().to_client()

flower = fl.flower.Flower(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
Expand Down
6 changes: 5 additions & 1 deletion e2e/mxnet/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,11 @@ def evaluate(self, parameters, config):


def client_fn(cid):
return FlowerClient()
return FlowerClient().to_client()

flower = fl.flower.Flower(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
Expand Down
6 changes: 5 additions & 1 deletion e2e/opacus/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,11 @@ def evaluate(self, parameters, config):

def client_fn(cid):
model = Net()
return FlowerClient(model)
return FlowerClient(model).to_client()

flower = fl.flower.Flower(
client_fn=client_fn,
)

if __name__ == "__main__":
fl.client.start_numpy_client(
Expand Down
6 changes: 5 additions & 1 deletion e2e/pandas/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ def fit(
)

def client_fn(cid):
return FlowerClient()
return FlowerClient().to_client()

flower = fl.flower.Flower(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
Expand Down
6 changes: 5 additions & 1 deletion e2e/pytorch-lightning/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ def client_fn(cid):
train_loader, val_loader, test_loader = mnist.load_data()

# Flower client
return FlowerClient(model, train_loader, val_loader, test_loader)
return FlowerClient(model, train_loader, val_loader, test_loader).to_client()

flower = fl.flower.Flower(
client_fn=client_fn,
)

def main() -> None:
# Model and data
Expand Down
6 changes: 5 additions & 1 deletion e2e/pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ def set_parameters(model, parameters):
return

def client_fn(cid):
return FlowerClient()
return FlowerClient().to_client()

flower = fl.flower.Flower(
client_fn=client_fn,
)


if __name__ == "__main__":
Expand Down
6 changes: 5 additions & 1 deletion e2e/scikit-learn/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ def evaluate(self, parameters, config): # type: ignore
return loss, len(X_test), {"accuracy": accuracy}

def client_fn(cid):
return FlowerClient()
return FlowerClient().to_client()

flower = fl.flower.Flower(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
Expand Down
9 changes: 9 additions & 0 deletions e2e/strategies/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ def evaluate(self, parameters, config):
return loss, len(x_test), {"accuracy": accuracy}


def client_fn(cid):
return FlowerClient().to_client()


flower = fl.flower.Flower(
client_fn=client_fn,
)


if __name__ == "__main__":
# Start Flower client
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient())
6 changes: 5 additions & 1 deletion e2e/tabnet/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ def evaluate(self, parameters, config):


def client_fn(cid):
return FlowerClient()
return FlowerClient().to_client()

flower = fl.flower.Flower(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
Expand Down
6 changes: 5 additions & 1 deletion e2e/tensorflow/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ def evaluate(self, parameters, config):
return loss, len(x_test), {"accuracy": accuracy}

def client_fn(cid):
return FlowerClient()
return FlowerClient().to_client()

flower = fl.flower.Flower(
client_fn=client_fn,
)

if __name__ == "__main__":
# Start Flower client
Expand Down
14 changes: 8 additions & 6 deletions e2e/test_driver.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@ set -e
case "$1" in
bare-https)
./generate.sh
cert_arg="--certificates certificates/ca.crt certificates/server.pem certificates/server.key"
server_arg="--certificates certificates/ca.crt certificates/server.pem certificates/server.key"
client_arg="--root-certificates certificates/ca.crt"
;;
*)
cert_arg="--insecure"
server_arg="--insecure"
client_arg="--insecure"
;;
esac

timeout 2m flower-server $cert_arg --grpc-bidi --grpc-bidi-fleet-api-address 0.0.0.0:8080 &
timeout 2m flower-server $server_arg &
sleep 3

python client.py &
timeout 2m flower-client $client_arg --callable client:flower --server 127.0.0.1:9092 &
sleep 3

python client.py &
timeout 2m flower-client $client_arg --callable client:flower --server 127.0.0.1:9092 &
sleep 3

timeout 2m python driver.py &
Expand All @@ -27,7 +29,7 @@ wait $pid
res=$?

if [[ "$res" = "0" ]];
then echo "Training worked correctly" && pkill python;
then echo "Training worked correctly" && pkill flower-client && pkill flower-server;
else echo "Training had an issue" && exit 1;
fi

Loading

0 comments on commit 875b6c7

Please sign in to comment.