Skip to content

Commit

Permalink
Merge branch 'main' into fds-improve-pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Oct 14, 2023
2 parents 22f5bf6 + 1c3aea7 commit bc290a2
Show file tree
Hide file tree
Showing 20 changed files with 147 additions and 119 deletions.
2 changes: 2 additions & 0 deletions baselines/doc/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ Flower Baselines

Flower Baselines are a collection of organised scripts used to reproduce results from well-known publications or benchmarks. You can check which baselines already exist and/or contribute your own baseline.

.. BASELINES_TABLE_ANCHOR
Tutorials
~~~~~~~~~

Expand Down
46 changes: 46 additions & 0 deletions dev/build-baseline-docs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,48 @@ initial_text=$(cat <<-END
END
)

table_body="\\
.. list-table:: \\
:widths: 15 15 50\\
:header-rows: 1\\
\\
* - Method\\
- Dataset\\
- Tags\\
.. BASELINES_TABLE_ENTRY\\
"


function add_table_entry ()
{
# extract lines from markdown file between --- and ---, preserving newlines and store in variable called metadata
metadata=$(awk '/^---$/{flag=1; next} flag; /^---$/{exit}' $1/README.md)

# get text after "title:" in metadata using sed
title=$(echo "$metadata" | sed -n 's/title: //p')

# get text after "url:" in metadata using sed
url=$(echo "$metadata" | sed -n 's/url: //p')

# get text after "labels:" in metadata using sed
labels=$(echo "$metadata" | sed -n 's/labels: //p' | sed 's/\[//g; s/\]//g')

# get text after "dataset:" in metadata using sed
dataset=$(echo "$metadata" | sed -n 's/dataset: //p' | sed 's/\[//g; s/\]//g')

table_entry="\\
* - \`$1 <$1.html>\`_\\
- $dataset\\
- $labels\\
\\
.. BASELINES_TABLE_ENTRY\
"
}


# Create Sphinx table block and header
! sed -i '' -e "s/.. BASELINES_TABLE_ANCHOR/$table_body/" $INDEX

! grep ":caption: References" $INDEX && echo "$initial_text" >> $INDEX && echo "" >> $INDEX

rm -f "baselines/doc/source/*.md"
Expand Down Expand Up @@ -46,6 +88,10 @@ for d in $(printf '%s\n' */ | sort -V); do
# For each baseline, insert the name of the baseline into the index file
echo " $baseline" >> $INDEX

# Add entry to the table
add_table_entry $baseline
! sed -i '' -e "s/.. BASELINES_TABLE_ENTRY/$table_entry/" $INDEX

fi
fi
done
Expand Down
8 changes: 4 additions & 4 deletions doc/source/ref-changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@

The types of the return values in the docstrings in two methods (`aggregate_fit` and `aggregate_evaluate`) now match the hint types in the code.

- **Unify client API** ([#2303](https://github.com/adap/flower/pull/2303))
- **Unify client API** ([#2303](https://github.com/adap/flower/pull/2303), [#2390](https://github.com/adap/flower/pull/2390), [#2493](https://github.com/adap/flower/pull/2493))

Using the `client_fn`, Flower clients can interchangeably run as standalone processes (i.e. via `start_client`) or in simulation (i.e. via `start_simulation`) without requiring changes to how the client class is defined and instantiated.
Using the `client_fn`, Flower clients can interchangeably run as standalone processes (i.e. via `start_client`) or in simulation (i.e. via `start_simulation`) without requiring changes to how the client class is defined and instantiated. Calling `start_numpy_client` is now deprecated.

- **Update Flower Baselines**

- FedProx ([#2210](https://github.com/adap/flower/pull/2210), [#2286](https://github.com/adap/flower/pull/2286))

- Baselines Docs ([#2290](https://github.com/adap/flower/pull/2290))
- Baselines Docs ([#2290](https://github.com/adap/flower/pull/2290), [#2400](https://github.com/adap/flower/pull/2400))

- **Update Flower Examples** ([#2384](https://github.com/adap/flower/pull/2384)), ([#2425](https://github.com/adap/flower/pull/2425))

- **General updates to baselines** ([#2301](https://github.com/adap/flower/pull/2301), [#2305](https://github.com/adap/flower/pull/2305), [#2307](https://github.com/adap/flower/pull/2307), [#2327](https://github.com/adap/flower/pull/2327))

- **General updates to the simulation engine** ([#2331](https://github.com/adap/flower/pull/2331), [#2447](https://github.com/adap/flower/pull/2447), [#2448](https://github.com/adap/flower/pull/2448))

- **General improvements** ([#2309](https://github.com/adap/flower/pull/2309), [#2310](https://github.com/adap/flower/pull/2310), [2313](https://github.com/adap/flower/pull/2313), [#2316](https://github.com/adap/flower/pull/2316), [2317](https://github.com/adap/flower/pull/2317),[#2349](https://github.com/adap/flower/pull/2349), [#2360](https://github.com/adap/flower/pull/2360), [#2402](https://github.com/adap/flower/pull/2402), [#2446](https://github.com/adap/flower/pull/2446), [#2493](https://github.com/adap/flower/pull/2493))
- **General improvements** ([#2309](https://github.com/adap/flower/pull/2309), [#2310](https://github.com/adap/flower/pull/2310), [2313](https://github.com/adap/flower/pull/2313), [#2316](https://github.com/adap/flower/pull/2316), [2317](https://github.com/adap/flower/pull/2317),[#2349](https://github.com/adap/flower/pull/2349), [#2360](https://github.com/adap/flower/pull/2360), [#2402](https://github.com/adap/flower/pull/2402), [#2446](https://github.com/adap/flower/pull/2446))

Flower received many improvements under the hood, too many to list here.

Expand Down
4 changes: 2 additions & 2 deletions examples/mt-pytorch/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import random
import time

from flwr.driver import Driver
from flwr.driver import GrpcDriver
from flwr.common import (
ServerMessage,
FitIns,
Expand Down Expand Up @@ -43,7 +43,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:


# -------------------------------------------------------------------------- Driver SDK
driver = Driver(driver_service_address="0.0.0.0:9091", certificates=None)
driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None)
# -------------------------------------------------------------------------- Driver SDK

anonymous_client_nodes = False
Expand Down
4 changes: 2 additions & 2 deletions examples/secaggplus-mt/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from workflows import get_workflow_factory

from flwr.common import Metrics, ndarrays_to_parameters
from flwr.driver import Driver
from flwr.driver import GrpcDriver
from flwr.proto import driver_pb2, node_pb2, task_pb2
from flwr.server import History

Expand Down Expand Up @@ -71,7 +71,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:


# -------------------------------------------------------------------------- Driver SDK
driver = Driver(driver_service_address="0.0.0.0:9091", certificates=None)
driver = GrpcDriver(driver_service_address="0.0.0.0:9091", certificates=None)
# -------------------------------------------------------------------------- Driver SDK

anonymous_client_nodes = False
Expand Down
4 changes: 0 additions & 4 deletions src/py/flwr/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,14 @@
from .app import start_numpy_client as start_numpy_client
from .client import Client as Client
from .numpy_client import NumPyClient as NumPyClient
from .numpy_client_wrapper import to_client as to_client
from .run import run_client as run_client
from .typing import ClientFn as ClientFn
from .typing import ClientLike as ClientLike

__all__ = [
"Client",
"ClientFn",
"ClientLike",
"NumPyClient",
"run_client",
"start_client",
"start_numpy_client",
"to_client",
]
54 changes: 25 additions & 29 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

import sys
import time
import warnings
from logging import INFO
from typing import Callable, Optional, Union
from typing import Optional, Union

from flwr.client.client import Client
from flwr.client.typing import ClientFn, ClientLike
from flwr.client.typing import ClientFn
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
from flwr.common.address import parse_address
from flwr.common.constant import (
Expand All @@ -40,7 +41,7 @@


def _check_actionable_client(
client: Optional[ClientLike], client_fn: Optional[ClientFn]
client: Optional[Client], client_fn: Optional[ClientFn]
) -> None:
if client_fn is None and client is None:
raise Exception("Both `client_fn` and `client` are `None`, but one is required")
Expand All @@ -57,7 +58,7 @@ def start_client(
*,
server_address: str,
client_fn: Optional[ClientFn] = None,
client: Optional[ClientLike] = None,
client: Optional[Client] = None,
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
root_certificates: Optional[Union[bytes, str]] = None,
transport: Optional[str] = None,
Expand Down Expand Up @@ -124,7 +125,7 @@ class `flwr.client.Client` (default: None)
# Wrap `Client` instance in `client_fn`
def single_client_factory(
cid: str, # pylint: disable=unused-argument
) -> ClientLike:
) -> Client:
if client is None: # Added this to keep mypy happy
raise Exception(
"Both `client_fn` and `client` are `None`, but one is required"
Expand Down Expand Up @@ -209,8 +210,7 @@ def single_client_factory(
def start_numpy_client(
*,
server_address: str,
client_fn: Optional[Callable[[str], NumPyClient]] = None,
client: Optional[NumPyClient] = None,
client: NumPyClient,
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
root_certificates: Optional[bytes] = None,
transport: Optional[str] = None,
Expand All @@ -223,9 +223,7 @@ def start_numpy_client(
The IPv4 or IPv6 address of the server. If the Flower server runs on
the same machine on port 8080, then `server_address` would be
`"[::]:8080"`.
client_fn : Optional[Callable[[str], NumPyClient]]
A callable that instantiates a NumPyClient. (default: None)
client : Optional[flwr.client.NumPyClient]
client : flwr.client.NumPyClient
An implementation of the abstract base class `flwr.client.NumPyClient`.
grpc_max_message_length : int (default: 536_870_912, this equals 512MB)
The maximum length of gRPC messages that can be exchanged with the
Expand All @@ -248,42 +246,40 @@ def start_numpy_client(
--------
Starting a client with an insecure server connection:
>>> def client_fn(cid: str):
>>> return FlowerClient()
>>>
>>> start_numpy_client(
>>> server_address=localhost:8080,
>>> client_fn=client_fn,
>>> client=FlowerClient(),
>>> )
Starting an SSL-enabled gRPC client:
>>> from pathlib import Path
>>> def client_fn(cid: str):
>>> return FlowerClient()
>>>
>>> start_numpy_client(
>>> server_address=localhost:8080,
>>> client_fn=client_fn,
>>> client=FlowerClient(),
>>> root_certificates=Path("/crts/root.pem").read_bytes(),
>>> )
"""
# Start
_check_actionable_client(client, client_fn)

wrp_client = client.to_client() if client else None
wrp_clientfn = None
if client_fn:
warnings.warn(
"flwr.client.start_numpy_client() is deprecated and will "
"be removed in a future version of Flower. Instead, pass "
"your client to `flwr.client.start_client()` by calling "
"first the `.to_client()` method as shown below: \n"
"\tflwr.client.start_client(\n"
"\t\tserver_address='<IP>:<PORT>',\n"
"\t\tclient=FlowerClient().to_client()\n"
"\t)",
DeprecationWarning,
stacklevel=2,
)

def convert(cid: str) -> Client:
"""Convert `NumPyClient` to `Client` upon instantiation."""
return client_fn(cid).to_client()
# Calling this function is deprecated. A warning is thrown.
# We first need to convert either the supplied client to `Client.`

wrp_clientfn = convert
wrp_client = client.to_client()

start_client(
server_address=server_address,
client_fn=wrp_clientfn,
client=wrp_client,
grpc_max_message_length=grpc_max_message_length,
root_certificates=root_certificates,
Expand Down
17 changes: 4 additions & 13 deletions src/py/flwr/client/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from typing import Dict, Tuple

from flwr.client import ClientLike, to_client
from flwr.common import (
Config,
EvaluateIns,
Expand Down Expand Up @@ -83,26 +82,18 @@ def evaluate(

def test_to_client_with_client() -> None:
"""Test to_client."""
# Prepare
client_like: ClientLike = PlainClient()

# Execute
actual = to_client(client_like=client_like)
client = PlainClient().to_client()

# Assert
assert isinstance(actual, Client)
assert isinstance(client, Client)


def test_to_client_with_numpyclient() -> None:
"""Test fit_clients."""
# Prepare
client_like: ClientLike = NeedsWrappingClient()

# Execute
actual = to_client(client_like=client_like)
client = NeedsWrappingClient().to_client()

# Assert
assert isinstance(actual, Client)
assert isinstance(client, Client)


def test_start_client_transport_invalid() -> None:
Expand Down
9 changes: 3 additions & 6 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@
get_server_message_from_task_ins,
wrap_client_message_in_task_res,
)
from flwr.client.numpy_client_wrapper import to_client
from flwr.client.secure_aggregation import SecureAggregationHandler
from flwr.client.typing import ClientFn, ClientLike
from flwr.client.typing import ClientFn
from flwr.common import serde
from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes
from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage
Expand Down Expand Up @@ -64,8 +63,7 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> Tuple[TaskRes, int, bool]:
server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False)
if server_msg is None:
# Instantiate the client
client_like: ClientLike = client_fn("-1")
client = to_client(client_like)
client = client_fn("-1")
# Secure Aggregation
if task_ins.task.HasField("sa") and isinstance(
client, SecureAggregationHandler
Expand Down Expand Up @@ -120,8 +118,7 @@ def handle_legacy_message(
return disconnect_msg, sleep_duration, False

# Instantiate the client
client_like: ClientLike = client_fn("-1")
client = to_client(client_like)
client = client_fn("-1")
# Execute task
if field == "get_properties_ins":
return _get_properties(client, server_msg.get_properties_ins), 0, True
Expand Down
27 changes: 0 additions & 27 deletions src/py/flwr/client/numpy_client_wrapper.py

This file was deleted.

6 changes: 2 additions & 4 deletions src/py/flwr/client/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
# ==============================================================================
"""Custom types for Flower clients."""

from typing import Callable, Union
from typing import Callable

from .client import Client as Client
from .numpy_client import NumPyClient as NumPyClient

ClientLike = Union[Client, NumPyClient]
ClientFn = Callable[[str], ClientLike]
ClientFn = Callable[[str], Client]
4 changes: 2 additions & 2 deletions src/py/flwr/driver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@


from .app import start_driver
from .driver import Driver
from .driver import GrpcDriver

__all__ = [
"start_driver",
"Driver",
"GrpcDriver",
]
Loading

0 comments on commit bc290a2

Please sign in to comment.