Skip to content

Commit

Permalink
Add NodeState to RayClientProxy (#2686)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Dec 7, 2023
1 parent 8b44a63 commit 17f488a
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 8 deletions.
1 change: 0 additions & 1 deletion src/py/flwr/simulation/ray_transport/ray_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def run(
job_results = job_fn(client)
# Retrieve state (potentially updated)
updated_state = client.get_state()
print(f"Actor finishing ({cid}) !!!: {updated_state = }")
except Exception as ex:
client_trace = traceback.format_exc()
message = (
Expand Down
23 changes: 20 additions & 3 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
maybe_call_get_parameters,
maybe_call_get_properties,
)
from flwr.client.workload_state import WorkloadState
from flwr.client.node_state import NodeState
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy
from flwr.simulation.ray_transport.ray_actor import (
Expand Down Expand Up @@ -129,14 +129,31 @@ def __init__(
super().__init__(cid)
self.client_fn = client_fn
self.actor_pool = actor_pool
self.proxy_state = NodeState()

def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes:
# The VCE is not exposed to TaskIns, so it won't handle multiple workloads.
# For the time being, fixing workload_id is a small compromise.
# This will be one of the first points to address when integrating VCE + DriverAPI.
workload_id = 0

# Register state
self.proxy_state.register_workloadstate(workload_id=workload_id)

# Retrieve state
state = self.proxy_state.retrieve_workloadstate(workload_id=workload_id)

try:
self.actor_pool.submit_client_job(
lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state),
(self.client_fn, job_fn, self.cid, WorkloadState(state={})),
(self.client_fn, job_fn, self.cid, state),
)
res, updated_state = self.actor_pool.get_client_result(self.cid, timeout)

# Update state
self.proxy_state.update_workloadstate(
workload_id=workload_id, workload_state=updated_state
)
res, _ = self.actor_pool.get_client_result(self.cid, timeout)

except Exception as ex:
if self.actor_pool.num_actors == 0:
Expand Down
23 changes: 19 additions & 4 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def job_fn(cid: str) -> JobFn: # pragma: no cover
def cid_times_pi(client: Client) -> ClientRes: # pylint: disable=unused-argument
result = int(cid) * pi

# store something in state
client.numpy_client.state.state["result"] = str(result) # type: ignore

# now let's convert it to a GetPropertiesRes response
return GetPropertiesRes(
status=Status(Code(0), message="test"), properties={"result": result}
Expand Down Expand Up @@ -109,28 +112,40 @@ def test_cid_consistency_one_at_a_time() -> None:
ray.shutdown()


def test_cid_consistency_all_submit_first() -> None:
def test_cid_consistency_all_submit_first_workload_consistency() -> None:
"""Test that ClientProxies get the result of client job they submit.
All jobs are submitted at the same time. Then fetched one at a time.
All jobs are submitted at the same time. Then fetched one at a time. This also tests
NodeState (at each Proxy) and WorkloadState basic functionality.
"""
proxies, _ = prep()
workload_id = 0

# submit all jobs (collect later)
shuffle(proxies)
for prox in proxies:
# Register state
prox.proxy_state.register_workloadstate(workload_id=workload_id)
# Retrieve state
state = prox.proxy_state.retrieve_workloadstate(workload_id=workload_id)

job = job_fn(prox.cid)
prox.actor_pool.submit_client_job(
lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state),
(prox.client_fn, job, prox.cid, WorkloadState(state={})),
(prox.client_fn, job, prox.cid, state),
)

# fetch results one at a time
shuffle(proxies)
for prox in proxies:
res, _ = prox.actor_pool.get_client_result(prox.cid, timeout=None)
res, updated_state = prox.actor_pool.get_client_result(prox.cid, timeout=None)
prox.proxy_state.update_workloadstate(workload_id, workload_state=updated_state)
res = cast(GetPropertiesRes, res)
assert int(prox.cid) * pi == res.properties["result"]
assert (
str(int(prox.cid) * pi)
== prox.proxy_state.retrieve_workloadstate(workload_id).state["result"]
)

ray.shutdown()

Expand Down

0 comments on commit 17f488a

Please sign in to comment.