diff --git a/src/py/flwr/server/fleet/grpc_bidi/driver_client_manager.py b/src/py/flwr/server/fleet/grpc_bidi/driver_client_manager.py index e01efb303773..ac53a03379ef 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/driver_client_manager.py +++ b/src/py/flwr/server/fleet/grpc_bidi/driver_client_manager.py @@ -72,7 +72,7 @@ def register(self, client: ClientProxy) -> bool: # Register node in with State state: State = self.state_factory.state() - client.node_id = state.register_node() + client.node_id = state.create_node() # Create and start the instruction scheduler ins_scheduler = InsScheduler( @@ -105,7 +105,7 @@ def unregister(self, client: ClientProxy) -> None: # Unregister node_id in with State state: State = self.state_factory.state() - state.unregister_node(node_id=node_id) + state.delete_node(node_id=node_id) with self._cv: self._cv.notify_all() diff --git a/src/py/flwr/server/fleet/message_handler/message_handler.py b/src/py/flwr/server/fleet/message_handler/message_handler.py index ad418bd1c905..1ee7d48da9a2 100644 --- a/src/py/flwr/server/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/fleet/message_handler/message_handler.py @@ -40,7 +40,7 @@ def create_node( ) -> CreateNodeResponse: """.""" # Register node - node_id = state.register_node() + node_id = state.create_node() return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False)) @@ -51,7 +51,7 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse: return DeleteNodeResponse() # Update state - state.unregister_node(node_id=request.node.node_id) + state.delete_node(node_id=request.node.node_id) return DeleteNodeResponse() diff --git a/src/py/flwr/server/fleet/message_handler/message_handler_test.py b/src/py/flwr/server/fleet/message_handler/message_handler_test.py index da92b267f082..e5d889adeb3f 100644 --- a/src/py/flwr/server/fleet/message_handler/message_handler_test.py +++ b/src/py/flwr/server/fleet/message_handler/message_handler_test.py @@ -39,8 +39,8 @@ def test_create_node() -> None: create_node(request=request, state=state) # Assert - state.register_node.assert_called_once() - state.unregister_node.assert_not_called() + state.create_node.assert_called_once() + state.delete_node.assert_not_called() state.store_task_ins.assert_not_called() state.get_task_ins.assert_not_called() state.store_task_res.assert_not_called() @@ -57,8 +57,8 @@ def test_delete_node_failure() -> None: delete_node(request=request, state=state) # Assert - state.register_node.assert_not_called() - state.unregister_node.assert_not_called() + state.create_node.assert_not_called() + state.delete_node.assert_not_called() state.store_task_ins.assert_not_called() state.get_task_ins.assert_not_called() state.store_task_res.assert_not_called() @@ -75,8 +75,8 @@ def test_delete_node_success() -> None: delete_node(request=request, state=state) # Assert - state.register_node.assert_not_called() - state.unregister_node.assert_called_once() + state.create_node.assert_not_called() + state.delete_node.assert_called_once() state.store_task_ins.assert_not_called() state.get_task_ins.assert_not_called() state.store_task_res.assert_not_called() @@ -93,8 +93,8 @@ def test_pull_task_ins() -> None: pull_task_ins(request=request, state=state) # Assert - state.register_node.assert_not_called() - state.unregister_node.assert_not_called() + state.create_node.assert_not_called() + state.delete_node.assert_not_called() state.store_task_ins.assert_not_called() state.get_task_ins.assert_called_once() state.store_task_res.assert_not_called() @@ -120,8 +120,8 @@ def test_push_task_res() -> None: push_task_res(request=request, state=state) # Assert - state.register_node.assert_not_called() - state.unregister_node.assert_not_called() + state.create_node.assert_not_called() + state.delete_node.assert_not_called() state.store_task_ins.assert_not_called() state.get_task_ins.assert_not_called() state.store_task_res.assert_called_once() diff --git a/src/py/flwr/server/state/in_memory_state.py b/src/py/flwr/server/state/in_memory_state.py index ca0221769147..228e9d5a6a7f 100644 --- a/src/py/flwr/server/state/in_memory_state.py +++ b/src/py/flwr/server/state/in_memory_state.py @@ -182,7 +182,7 @@ def num_task_res(self) -> int: """ return len(self.task_res_store) - def register_node(self) -> int: + def create_node(self) -> int: """Create, store in state, and return `node_id`.""" # Sample a random 64-bit unsigned integer as node_id node_id: int = random.getrandbits(64) @@ -193,7 +193,7 @@ def register_node(self) -> int: log(ERROR, "Unexpected node registration failure.") return 0 - def unregister_node(self, node_id: int) -> None: + def delete_node(self, node_id: int) -> None: """Unregister a client node.""" if node_id not in self.node_ids: raise ValueError(f"Node {node_id} is not registered") diff --git a/src/py/flwr/server/state/sqlite_state.py b/src/py/flwr/server/state/sqlite_state.py index addded3025f7..f7bd47230dc7 100644 --- a/src/py/flwr/server/state/sqlite_state.py +++ b/src/py/flwr/server/state/sqlite_state.py @@ -471,7 +471,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None: return None - def register_node(self) -> int: + def create_node(self) -> int: """Create, store in state, and return `node_id`.""" # Sample a random 64-bit unsigned integer as node_id node_id = random.getrandbits(64) @@ -489,7 +489,7 @@ def register_node(self) -> int: log(ERROR, "Unexpected node registration failure.") return 0 - def unregister_node(self, node_id: int) -> None: + def delete_node(self, node_id: int) -> None: """Remove `node_id` from state.""" sql_node_id = uint64_to_int64(node_id) query = "DELETE FROM node WHERE node_id = :node_id;" diff --git a/src/py/flwr/server/state/state.py b/src/py/flwr/server/state/state.py index b69d851453ce..1e08d9e4f5b7 100644 --- a/src/py/flwr/server/state/state.py +++ b/src/py/flwr/server/state/state.py @@ -132,11 +132,11 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None: """Delete all delivered TaskIns/TaskRes pairs.""" @abc.abstractmethod - def register_node(self) -> int: + def create_node(self) -> int: """Create, store in state, and return `node_id`.""" @abc.abstractmethod - def unregister_node(self, node_id: int) -> None: + def delete_node(self, node_id: int) -> None: """Remove `node_id` from state.""" @abc.abstractmethod diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/state/state_test.py index 19973c3bbcb1..d64cdb99ff5f 100644 --- a/src/py/flwr/server/state/state_test.py +++ b/src/py/flwr/server/state/state_test.py @@ -335,7 +335,7 @@ def test_register_node_and_get_nodes(self) -> None: # Execute for _ in range(10): - node_ids.append(state.register_node()) + node_ids.append(state.create_node()) retrieved_node_ids = state.get_nodes(workload_id) # Assert @@ -347,10 +347,10 @@ def test_unregister_node(self) -> None: # Prepare state: State = self.state_factory() workload_id = state.create_workload() - node_id = state.register_node() + node_id = state.create_node() # Execute - state.unregister_node(node_id) + state.delete_node(node_id) retrieved_node_ids = state.get_nodes(workload_id) # Assert @@ -362,7 +362,7 @@ def test_get_nodes_invalid_workload_id(self) -> None: state: State = self.state_factory() state.create_workload() invalid_workload_id = 61016 - state.register_node() + state.create_node() # Execute retrieved_node_ids = state.get_nodes(invalid_workload_id)