Skip to content

Commit

Permalink
feat(framework) Update ServerAppIoServicer RPCs for flwr stop (#4629
Browse files Browse the repository at this point in the history
)
  • Loading branch information
chongshenng authored Dec 11, 2024
1 parent a4e579a commit 9a05a34
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 23 deletions.
9 changes: 8 additions & 1 deletion src/proto/flwr/proto/serverappio.proto
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ service ServerAppIo {
rpc UpdateRunStatus(UpdateRunStatusRequest)
returns (UpdateRunStatusResponse) {}

// Get the status of a given run
rpc GetRunStatus(GetRunStatusRequest) returns (GetRunStatusResponse) {}

// Push ServerApp logs
rpc PushLogs(PushLogsRequest) returns (PushLogsResponse) {}
}
Expand All @@ -64,13 +67,17 @@ message GetNodesRequest { uint64 run_id = 1; }
message GetNodesResponse { repeated Node nodes = 1; }

// PushTaskIns messages
message PushTaskInsRequest { repeated TaskIns task_ins_list = 1; }
message PushTaskInsRequest {
repeated TaskIns task_ins_list = 1;
uint64 run_id = 2;
}
message PushTaskInsResponse { repeated string task_ids = 2; }

// PullTaskRes messages
message PullTaskResRequest {
Node node = 1;
repeated string task_ids = 2;
uint64 run_id = 3;
}
message PullTaskResResponse { repeated TaskRes task_res_list = 1; }

Expand Down
36 changes: 18 additions & 18 deletions src/py/flwr/proto/serverappio_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 8 additions & 2 deletions src/py/flwr/proto/serverappio_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,16 @@ class PushTaskInsRequest(google.protobuf.message.Message):
"""PushTaskIns messages"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TASK_INS_LIST_FIELD_NUMBER: builtins.int
RUN_ID_FIELD_NUMBER: builtins.int
@property
def task_ins_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.task_pb2.TaskIns]: ...
run_id: builtins.int
def __init__(self,
*,
task_ins_list: typing.Optional[typing.Iterable[flwr.proto.task_pb2.TaskIns]] = ...,
run_id: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["task_ins_list",b"task_ins_list"]) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id","task_ins_list",b"task_ins_list"]) -> None: ...
global___PushTaskInsRequest = PushTaskInsRequest

class PushTaskInsResponse(google.protobuf.message.Message):
Expand All @@ -70,17 +73,20 @@ class PullTaskResRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
NODE_FIELD_NUMBER: builtins.int
TASK_IDS_FIELD_NUMBER: builtins.int
RUN_ID_FIELD_NUMBER: builtins.int
@property
def node(self) -> flwr.proto.node_pb2.Node: ...
@property
def task_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ...
run_id: builtins.int
def __init__(self,
*,
node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
task_ids: typing.Optional[typing.Iterable[typing.Text]] = ...,
run_id: builtins.int = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["node",b"node","task_ids",b"task_ids"]) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["node",b"node","run_id",b"run_id","task_ids",b"task_ids"]) -> None: ...
global___PullTaskResRequest = PullTaskResRequest

class PullTaskResResponse(google.protobuf.message.Message):
Expand Down
34 changes: 34 additions & 0 deletions src/py/flwr/proto/serverappio_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ def __init__(self, channel):
request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString,
)
self.GetRunStatus = channel.unary_unary(
'/flwr.proto.ServerAppIo/GetRunStatus',
request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
)
self.PushLogs = channel.unary_unary(
'/flwr.proto.ServerAppIo/PushLogs',
request_serializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.SerializeToString,
Expand Down Expand Up @@ -135,6 +140,13 @@ def UpdateRunStatus(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetRunStatus(self, request, context):
"""Get the status of a given run
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def PushLogs(self, request, context):
"""Push ServerApp logs
"""
Expand Down Expand Up @@ -190,6 +202,11 @@ def add_ServerAppIoServicer_to_server(servicer, server):
request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString,
response_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.SerializeToString,
),
'GetRunStatus': grpc.unary_unary_rpc_method_handler(
servicer.GetRunStatus,
request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString,
response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString,
),
'PushLogs': grpc.unary_unary_rpc_method_handler(
servicer.PushLogs,
request_deserializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.FromString,
Expand Down Expand Up @@ -358,6 +375,23 @@ def UpdateRunStatus(request,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def GetRunStatus(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.ServerAppIo/GetRunStatus',
flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def PushLogs(request,
target,
Expand Down
13 changes: 13 additions & 0 deletions src/py/flwr/proto/serverappio_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ class ServerAppIoStub:
flwr.proto.run_pb2.UpdateRunStatusResponse]
"""Update the status of a given run"""

GetRunStatus: grpc.UnaryUnaryMultiCallable[
flwr.proto.run_pb2.GetRunStatusRequest,
flwr.proto.run_pb2.GetRunStatusResponse]
"""Get the status of a given run"""

PushLogs: grpc.UnaryUnaryMultiCallable[
flwr.proto.log_pb2.PushLogsRequest,
flwr.proto.log_pb2.PushLogsResponse]
Expand Down Expand Up @@ -135,6 +140,14 @@ class ServerAppIoServicer(metaclass=abc.ABCMeta):
"""Update the status of a given run"""
pass

@abc.abstractmethod
def GetRunStatus(self,
request: flwr.proto.run_pb2.GetRunStatusRequest,
context: grpc.ServicerContext,
) -> flwr.proto.run_pb2.GetRunStatusResponse:
"""Get the status of a given run"""
pass

@abc.abstractmethod
def PushLogs(self,
request: flwr.proto.log_pb2.PushLogsRequest,
Expand Down
8 changes: 6 additions & 2 deletions src/py/flwr/server/driver/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
task_ins_list.append(taskins)
# Call GrpcDriverStub method
res: PushTaskInsResponse = self._stub.PushTaskIns(
PushTaskInsRequest(task_ins_list=task_ins_list)
PushTaskInsRequest(
task_ins_list=task_ins_list, run_id=cast(Run, self._run).run_id
)
)
return list(res.task_ids)

Expand All @@ -215,7 +217,9 @@ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
"""
# Pull TaskRes
res: PullTaskResResponse = self._stub.PullTaskRes(
PullTaskResRequest(node=self.node, task_ids=message_ids)
PullTaskResRequest(
node=self.node, task_ids=message_ids, run_id=cast(Run, self._run).run_id
)
)
# Convert TaskRes to Message
msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
Expand Down
18 changes: 18 additions & 0 deletions src/py/flwr/server/superlink/driver/serverappio_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
fab_from_proto,
fab_to_proto,
run_status_from_proto,
run_status_to_proto,
run_to_proto,
user_config_from_proto,
)
Expand All @@ -48,6 +49,8 @@
CreateRunResponse,
GetRunRequest,
GetRunResponse,
GetRunStatusRequest,
GetRunStatusResponse,
UpdateRunStatusRequest,
UpdateRunStatusResponse,
)
Expand Down Expand Up @@ -284,6 +287,21 @@ def PushLogs(
state.add_serverapp_log(request.run_id, merged_logs)
return PushLogsResponse()

def GetRunStatus(
self, request: GetRunStatusRequest, context: grpc.ServicerContext
) -> GetRunStatusResponse:
"""Get the status of a run."""
log(DEBUG, "ServerAppIoServicer.GetRunStatus")
state = self.state_factory.state()

# Get run status from LinkState
run_statuses = state.get_run_status(set(request.run_ids))
run_status_dict = {
run_id: run_status_to_proto(run_status)
for run_id, run_status in run_statuses.items()
}
return GetRunStatusResponse(run_status_dict=run_status_dict)


def _raise_if(validation_error: bool, detail: str) -> None:
if validation_error:
Expand Down

0 comments on commit 9a05a34

Please sign in to comment.