Skip to content

Commit

Permalink
Merge branch 'main' into update-flower-server
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Nov 30, 2023
2 parents 6f6a5e3 + eaa7ff9 commit a8a1eeb
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 21 deletions.
2 changes: 2 additions & 0 deletions doc/source/ref-changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ We would like to give our special thanks to all the contributors who made the ne

- **Introduce `WorkloadState`** ([#2564](https://github.com/adap/flower/pull/2564), [#2632](https://github.com/adap/flower/pull/2632))

- **Introduce `WorkloadState`** ([#2564](https://github.com/adap/flower/pull/2564), [#2632](https://github.com/adap/flower/pull/2632))

- **Update Flower Baselines**

- FedProx ([#2210](https://github.com/adap/flower/pull/2210), [#2286](https://github.com/adap/flower/pull/2286), [#2509](https://github.com/adap/flower/pull/2509))
Expand Down
38 changes: 26 additions & 12 deletions src/py/flwr/simulation/ray_transport/ray_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from flwr import common
from flwr.client import Client, ClientFn
from flwr.client.workload_state import WorkloadState
from flwr.common.logger import log
from flwr.simulation.ray_transport.utils import check_clientfn_returns_client

Expand Down Expand Up @@ -60,16 +61,22 @@ def run(
client_fn: ClientFn,
job_fn: JobFn,
cid: str,
) -> Tuple[str, ClientRes]:
state: WorkloadState,
) -> Tuple[str, ClientRes, WorkloadState]:
"""Run a client workload."""
# Execute tasks and return result
# return also cid which is needed to ensure results
# from the pool are correctly assigned to each ClientProxy
try:
# Instantiate client (check 'Client' type is returned)
client = check_clientfn_returns_client(client_fn(cid))
# Inject state
client.set_state(state)
# Run client job
job_results = job_fn(client)
# Retrieve state (potentially updated)
updated_state = client.get_state()
print(f"Actor finishing ({cid}) !!!: {updated_state = }")
except Exception as ex:
client_trace = traceback.format_exc()
message = (
Expand All @@ -83,7 +90,7 @@ def run(
)
raise ClientException(str(message)) from ex

return cid, job_results
return cid, job_results, updated_state


@ray.remote
Expand Down Expand Up @@ -231,16 +238,18 @@ def add_actors_to_pool(self, num_actors: int) -> None:
self._idle_actors.extend(new_actors)
self.num_actors += num_actors

def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str]) -> None:
def submit(
self, fn: Any, value: Tuple[ClientFn, JobFn, str, WorkloadState]
) -> None:
"""Take idle actor and assign it a client workload.
Submit a job to an actor by first removing it from the list of idle actors, then
check if this actor was flagged to be removed from the pool
"""
client_fn, job_fn, cid = value
client_fn, job_fn, cid, state = value
actor = self._idle_actors.pop()
if self._check_and_remove_actor_from_pool(actor):
future = fn(actor, client_fn, job_fn, cid)
future = fn(actor, client_fn, job_fn, cid, state)
future_key = tuple(future) if isinstance(future, List) else future
self._future_to_actor[future_key] = (self._next_task_index, actor, cid)
self._next_task_index += 1
Expand All @@ -249,10 +258,10 @@ def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str]) -> None:
self._cid_to_future[cid]["future"] = future_key

def submit_client_job(
self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str]
self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, WorkloadState]
) -> None:
"""Submit a job while tracking client ids."""
_, _, cid = job
_, _, cid, _ = job

# We need to put this behind a lock since .submit() involves
# removing and adding elements from a dictionary. Which creates
Expand Down Expand Up @@ -289,15 +298,17 @@ def _is_future_ready(self, cid: str) -> bool:

return self._cid_to_future[cid]["ready"] # type: ignore

def _fetch_future_result(self, cid: str) -> ClientRes:
"""Fetch result for VirtualClient from Object Store.
def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, WorkloadState]:
"""Fetch result and updated state for a VirtualClient from Object Store.
The job submitted by the ClientProxy interfacing with client with cid=cid is
ready. Here we fetch it from the object store and return.
"""
try:
future: ObjectRef[Any] = self._cid_to_future[cid]["future"] # type: ignore
res_cid, res = ray.get(future) # type: (str, ClientRes)
res_cid, res, updated_state = ray.get(
future
) # type: (str, ClientRes, WorkloadState)
except ray.exceptions.RayActorError as ex:
log(ERROR, ex)
if hasattr(ex, "actor_id"):
Expand All @@ -314,7 +325,7 @@ def _fetch_future_result(self, cid: str) -> ClientRes:
# Reset mapping
self._reset_cid_to_future_dict(cid)

return res
return res, updated_state

def _flag_actor_for_removal(self, actor_id_hex: str) -> None:
"""Flag actor that should be removed from pool."""
Expand Down Expand Up @@ -399,7 +410,9 @@ def process_unordered_future(self, timeout: Optional[float] = None) -> None:
# Manually terminate the actor
actor.terminate.remote()

def get_client_result(self, cid: str, timeout: Optional[float]) -> ClientRes:
def get_client_result(
self, cid: str, timeout: Optional[float]
) -> Tuple[ClientRes, WorkloadState]:
"""Get result from VirtualClient with specific cid."""
# Loop until all jobs submitted to the pool are completed. Break early
# if the result for the ClientProxy calling this method is ready
Expand All @@ -411,4 +424,5 @@ def get_client_result(self, cid: str, timeout: Optional[float]) -> ClientRes:
break

# Fetch result belonging to the VirtualClient calling this method
# Return both result from tasks and (potentially) updated workload state
return self._fetch_future_result(cid)
7 changes: 4 additions & 3 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
maybe_call_get_parameters,
maybe_call_get_properties,
)
from flwr.client.workload_state import WorkloadState
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy
from flwr.simulation.ray_transport.ray_actor import (
Expand Down Expand Up @@ -132,10 +133,10 @@ def __init__(
def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes:
try:
self.actor_pool.submit_client_job(
lambda a, c_fn, j_fn, cid: a.run.remote(c_fn, j_fn, cid),
(self.client_fn, job_fn, self.cid),
lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state),
(self.client_fn, job_fn, self.cid, WorkloadState(state={})),
)
res = self.actor_pool.get_client_result(self.cid, timeout)
res, _ = self.actor_pool.get_client_result(self.cid, timeout)

except Exception as ex:
if self.actor_pool.num_actors == 0:
Expand Down
13 changes: 7 additions & 6 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import ray

from flwr.client import Client, NumPyClient
from flwr.client.workload_state import WorkloadState
from flwr.common import Code, GetPropertiesRes, Status
from flwr.simulation.ray_transport.ray_actor import (
ClientRes,
Expand Down Expand Up @@ -120,14 +121,14 @@ def test_cid_consistency_all_submit_first() -> None:
for prox in proxies:
job = job_fn(prox.cid)
prox.actor_pool.submit_client_job(
lambda a, c_fn, j_fn, cid: a.run.remote(c_fn, j_fn, cid),
(prox.client_fn, job, prox.cid),
lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state),
(prox.client_fn, job, prox.cid, WorkloadState(state={})),
)

# fetch results one at a time
shuffle(proxies)
for prox in proxies:
res = prox.actor_pool.get_client_result(prox.cid, timeout=None)
res, _ = prox.actor_pool.get_client_result(prox.cid, timeout=None)
res = cast(GetPropertiesRes, res)
assert int(prox.cid) * pi == res.properties["result"]

Expand All @@ -145,14 +146,14 @@ def test_cid_consistency_without_proxies() -> None:
for cid in cids:
job = job_fn(cid)
pool.submit_client_job(
lambda a, c_fn, j_fn, cid_: a.run.remote(c_fn, j_fn, cid_),
(get_dummy_client, job, cid),
lambda a, c_fn, j_fn, cid_, state: a.run.remote(c_fn, j_fn, cid_, state),
(get_dummy_client, job, cid, WorkloadState(state={})),
)

# fetch results one at a time
shuffle(cids)
for cid in cids:
res = pool.get_client_result(cid, timeout=None)
res, _ = pool.get_client_result(cid, timeout=None)
res = cast(GetPropertiesRes, res)
assert int(cid) * pi == res.properties["result"]

Expand Down

0 comments on commit a8a1eeb

Please sign in to comment.