From 17f488a6f4399f6a3eea175800afad404f5d960f Mon Sep 17 00:00:00 2001 From: Javier Date: Thu, 7 Dec 2023 12:56:23 +0000 Subject: [PATCH] Add `NodeState` to `RayClientProxy` (#2686) --- .../simulation/ray_transport/ray_actor.py | 1 - .../ray_transport/ray_client_proxy.py | 23 ++++++++++++++++--- .../ray_transport/ray_client_proxy_test.py | 23 +++++++++++++++---- 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 57ea0ed7b187..640817910396 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -76,7 +76,6 @@ def run( job_results = job_fn(client) # Retrieve state (potentially updated) updated_state = client.get_state() - print(f"Actor finishing ({cid}) !!!: {updated_state = }") except Exception as ex: client_trace = traceback.format_exc() message = ( diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 9596acaf1d91..c6a63298dae6 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -29,7 +29,7 @@ maybe_call_get_parameters, maybe_call_get_properties, ) -from flwr.client.workload_state import WorkloadState +from flwr.client.node_state import NodeState from flwr.common.logger import log from flwr.server.client_proxy import ClientProxy from flwr.simulation.ray_transport.ray_actor import ( @@ -129,14 +129,31 @@ def __init__( super().__init__(cid) self.client_fn = client_fn self.actor_pool = actor_pool + self.proxy_state = NodeState() def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: + # The VCE is not exposed to TaskIns, it won't handle multilple workloads + # For the time being, fixing workload_id is a small compromise + # This will be one of the first points to address integrating VCE + DriverAPI + workload_id = 0 + + # Register state + self.proxy_state.register_workloadstate(workload_id=workload_id) + + # Retrieve state + state = self.proxy_state.retrieve_workloadstate(workload_id=workload_id) + try: self.actor_pool.submit_client_job( lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state), - (self.client_fn, job_fn, self.cid, WorkloadState(state={})), + (self.client_fn, job_fn, self.cid, state), + ) + res, updated_state = self.actor_pool.get_client_result(self.cid, timeout) + + # Update state + self.proxy_state.update_workloadstate( + workload_id=workload_id, workload_state=updated_state ) - res, _ = self.actor_pool.get_client_result(self.cid, timeout) except Exception as ex: if self.actor_pool.num_actors == 0: diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 44cb4ec70471..b87418b671d3 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -53,6 +53,9 @@ def job_fn(cid: str) -> JobFn: # pragma: no cover def cid_times_pi(client: Client) -> ClientRes: # pylint: disable=unused-argument result = int(cid) * pi + # store something in state + client.numpy_client.state.state["result"] = str(result) # type: ignore + # now let's convert it to a GetPropertiesRes response return GetPropertiesRes( status=Status(Code(0), message="test"), properties={"result": result} @@ -109,28 +112,40 @@ def test_cid_consistency_one_at_a_time() -> None: ray.shutdown() -def test_cid_consistency_all_submit_first() -> None: +def test_cid_consistency_all_submit_first_workload_consistency() -> None: """Test that ClientProxies get the result of client job they submit. - All jobs are submitted at the same time. Then fetched one at a time. + All jobs are submitted at the same time. Then fetched one at a time. This also tests + NodeState (at each Proxy) and WorkloadState basic functionality. """ proxies, _ = prep() + workload_id = 0 # submit all jobs (collect later) shuffle(proxies) for prox in proxies: + # Register state + prox.proxy_state.register_workloadstate(workload_id=workload_id) + # Retrieve state + state = prox.proxy_state.retrieve_workloadstate(workload_id=workload_id) + job = job_fn(prox.cid) prox.actor_pool.submit_client_job( lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state), - (prox.client_fn, job, prox.cid, WorkloadState(state={})), + (prox.client_fn, job, prox.cid, state), ) # fetch results one at a time shuffle(proxies) for prox in proxies: - res, _ = prox.actor_pool.get_client_result(prox.cid, timeout=None) + res, updated_state = prox.actor_pool.get_client_result(prox.cid, timeout=None) + prox.proxy_state.update_workloadstate(workload_id, workload_state=updated_state) res = cast(GetPropertiesRes, res) assert int(prox.cid) * pi == res.properties["result"] + assert ( + str(int(prox.cid) * pi) + == prox.proxy_state.retrieve_workloadstate(workload_id).state["result"] + ) ray.shutdown()