Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move node_id generation into State (#2401) #2401

Merged
merged 14 commits into from
Oct 7, 2023
15 changes: 5 additions & 10 deletions src/py/flwr/server/fleet/grpc_bidi/driver_client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Flower DriverClientManager."""


import random
import threading
from typing import Dict, List, Optional, Set, Tuple

Expand Down Expand Up @@ -71,13 +70,9 @@ def register(self, client: ClientProxy) -> bool:
if client.cid in self.nodes:
return False

# Generate random integer ID
random_node_id: int = random.randrange(9223372036854775808)
client.node_id = random_node_id

# Register node_id in with State
# Create node in State
state: State = self.state_factory.state()
state.register_node(node_id=random_node_id)
client.node_id = state.create_node()

# Create and start the instruction scheduler
ins_scheduler = InsScheduler(
Expand All @@ -87,7 +82,7 @@ def register(self, client: ClientProxy) -> bool:
ins_scheduler.start()

# Store cid, node_id, and InsScheduler
self.nodes[client.cid] = (random_node_id, ins_scheduler)
self.nodes[client.cid] = (client.node_id, ins_scheduler)

with self._cv:
self._cv.notify_all()
Expand All @@ -108,9 +103,9 @@ def unregister(self, client: ClientProxy) -> None:
del self.nodes[client.cid]
ins_scheduler.stop()

# Unregister node_id in with State
# Delete node_id in State
state: State = self.state_factory.state()
state.unregister_node(node_id=node_id)
state.delete_node(node_id=node_id)

with self._cv:
self._cv.notify_all()
Expand Down
12 changes: 4 additions & 8 deletions src/py/flwr/server/fleet/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Fleet API message handlers."""


import random
from typing import List, Optional
from uuid import UUID

Expand All @@ -40,12 +39,9 @@ def create_node(
state: State,
) -> CreateNodeResponse:
"""."""
# Generate random node_id
random_node_id: int = random.randrange(9223372036854775808)

# Update state
state.register_node(node_id=random_node_id)
return CreateNodeResponse(node=Node(node_id=random_node_id, anonymous=False))
# Create node
node_id = state.create_node()
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))


def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
Expand All @@ -55,7 +51,7 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
return DeleteNodeResponse()

# Update state
state.unregister_node(node_id=request.node.node_id)
state.delete_node(node_id=request.node.node_id)
return DeleteNodeResponse()


Expand Down
20 changes: 10 additions & 10 deletions src/py/flwr/server/fleet/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def test_create_node() -> None:
create_node(request=request, state=state)

# Assert
state.register_node.assert_called_once()
state.unregister_node.assert_not_called()
state.create_node.assert_called_once()
state.delete_node.assert_not_called()
state.store_task_ins.assert_not_called()
state.get_task_ins.assert_not_called()
state.store_task_res.assert_not_called()
Expand All @@ -57,8 +57,8 @@ def test_delete_node_failure() -> None:
delete_node(request=request, state=state)

# Assert
state.register_node.assert_not_called()
state.unregister_node.assert_not_called()
state.create_node.assert_not_called()
state.delete_node.assert_not_called()
state.store_task_ins.assert_not_called()
state.get_task_ins.assert_not_called()
state.store_task_res.assert_not_called()
Expand All @@ -75,8 +75,8 @@ def test_delete_node_success() -> None:
delete_node(request=request, state=state)

# Assert
state.register_node.assert_not_called()
state.unregister_node.assert_called_once()
state.create_node.assert_not_called()
state.delete_node.assert_called_once()
state.store_task_ins.assert_not_called()
state.get_task_ins.assert_not_called()
state.store_task_res.assert_not_called()
Expand All @@ -93,8 +93,8 @@ def test_pull_task_ins() -> None:
pull_task_ins(request=request, state=state)

# Assert
state.register_node.assert_not_called()
state.unregister_node.assert_not_called()
state.create_node.assert_not_called()
state.delete_node.assert_not_called()
state.store_task_ins.assert_not_called()
state.get_task_ins.assert_called_once()
state.store_task_res.assert_not_called()
Expand All @@ -120,8 +120,8 @@ def test_push_task_res() -> None:
push_task_res(request=request, state=state)

# Assert
state.register_node.assert_not_called()
state.unregister_node.assert_not_called()
state.create_node.assert_not_called()
state.delete_node.assert_not_called()
state.store_task_ins.assert_not_called()
state.get_task_ins.assert_not_called()
state.store_task_res.assert_called_once()
Expand Down
27 changes: 16 additions & 11 deletions src/py/flwr/server/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""In-memory State implementation."""


import random
import os
from datetime import datetime, timedelta
from logging import ERROR
from typing import Dict, List, Optional, Set
Expand Down Expand Up @@ -182,16 +182,21 @@ def num_task_res(self) -> int:
"""
return len(self.task_res_store)

def register_node(self, node_id: int) -> None:
"""Register a client node."""
if node_id in self.node_ids:
raise ValueError(f"Node {node_id} is already registered")
self.node_ids.add(node_id)
def create_node(self) -> int:
"""Create, store in state, and return `node_id`."""
# Sample a random int64 as node_id
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)

def unregister_node(self, node_id: int) -> None:
"""Unregister a client node."""
if node_id not in self.node_ids:
raise ValueError(f"Node {node_id} is not registered")
self.node_ids.add(node_id)
return node_id
log(ERROR, "Unexpected node registration failure.")
return 0

def delete_node(self, node_id: int) -> None:
"""Delete a client node."""
if node_id not in self.node_ids:
raise ValueError(f"Node {node_id} not found")
self.node_ids.remove(node_id)

def get_nodes(self, workload_id: int) -> Set[int]:
Expand All @@ -208,8 +213,8 @@ def get_nodes(self, workload_id: int) -> Set[int]:

def create_workload(self) -> int:
"""Create one workload."""
# Sample random integer from 0 to 9223372036854775807
workload_id: int = random.randrange(9223372036854775808)
# Sample a random int64 as workload_id
workload_id: int = int.from_bytes(os.urandom(8), "little", signed=True)

if workload_id not in self.workload_ids:
self.workload_ids.add(workload_id)
Expand Down
24 changes: 16 additions & 8 deletions src/py/flwr/server/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""SQLite based implementation of server state."""


import random
import os
import re
import sqlite3
from datetime import datetime, timedelta
Expand Down Expand Up @@ -469,13 +469,21 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None:

return None

def register_node(self, node_id: int) -> None:
"""Store `node_id` in state."""
def create_node(self) -> int:
"""Create, store in state, and return `node_id`."""
# Sample a random int64 as node_id
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)

query = "INSERT INTO node VALUES(:node_id);"
self.query(query, {"node_id": node_id})
try:
self.query(query, {"node_id": node_id})
except sqlite3.IntegrityError:
log(ERROR, "Unexpected node registration failure.")
return 0
return node_id

def unregister_node(self, node_id: int) -> None:
"""Remove `node_id` from state."""
def delete_node(self, node_id: int) -> None:
"""Delete a client node."""
query = "DELETE FROM node WHERE node_id = :node_id;"
self.query(query, {"node_id": node_id})

Expand All @@ -500,8 +508,8 @@ def get_nodes(self, workload_id: int) -> Set[int]:

def create_workload(self) -> int:
"""Create one workload and store it in state."""
# Sample random integer from 0 to 9223372036854775807
workload_id: int = random.randrange(9223372036854775808)
# Sample a random int64 as workload_id
workload_id: int = int.from_bytes(os.urandom(8), "little", signed=True)

# Check conflicts
query = "SELECT COUNT(*) FROM workload WHERE workload_id = ?;"
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/server/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,11 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None:
"""Delete all delivered TaskIns/TaskRes pairs."""

@abc.abstractmethod
def register_node(self, node_id: int) -> None:
"""Store `node_id` in state."""
def create_node(self) -> int:
"""Create, store in state, and return `node_id`."""

@abc.abstractmethod
def unregister_node(self, node_id: int) -> None:
def delete_node(self, node_id: int) -> None:
"""Remove `node_id` from state."""

@abc.abstractmethod
Expand Down
22 changes: 10 additions & 12 deletions src/py/flwr/server/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,32 +326,31 @@ def test_node_ids_initial_state(self) -> None:
# Assert
assert len(retrieved_node_ids) == 0

def test_register_node_and_get_nodes(self) -> None:
"""Test registering a client node."""
def test_create_node_and_get_nodes(self) -> None:
"""Test creating a client node."""
# Prepare
state: State = self.state_factory()
workload_id = state.create_workload()
node_ids = list(range(1, 11))
node_ids = []

# Execute
for i in node_ids:
state.register_node(i)
for _ in range(10):
node_ids.append(state.create_node())
retrieved_node_ids = state.get_nodes(workload_id)

# Assert
for i in retrieved_node_ids:
assert i in node_ids

def test_unregister_node(self) -> None:
"""Test unregistering a client node."""
def test_delete_node(self) -> None:
"""Test deleting a client node."""
# Prepare
state: State = self.state_factory()
workload_id = state.create_workload()
node_id = 2
node_id = state.create_node()

# Execute
state.register_node(node_id)
state.unregister_node(node_id)
state.delete_node(node_id)
retrieved_node_ids = state.get_nodes(workload_id)

# Assert
Expand All @@ -363,10 +362,9 @@ def test_get_nodes_invalid_workload_id(self) -> None:
state: State = self.state_factory()
state.create_workload()
invalid_workload_id = 61016
node_id = 2
state.create_node()

# Execute
state.register_node(node_id)
retrieved_node_ids = state.get_nodes(invalid_workload_id)

# Assert
Expand Down