From 9a05a343b33c3f04ec0ba606956c1cb6172204de Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 11 Dec 2024 15:47:44 +0000 Subject: [PATCH] feat(framework) Update `ServerAppIoServicer` RPCs for `flwr stop` (#4629) --- src/proto/flwr/proto/serverappio.proto | 9 ++++- src/py/flwr/proto/serverappio_pb2.py | 36 +++++++++---------- src/py/flwr/proto/serverappio_pb2.pyi | 10 ++++-- src/py/flwr/proto/serverappio_pb2_grpc.py | 34 ++++++++++++++++++ src/py/flwr/proto/serverappio_pb2_grpc.pyi | 13 +++++++ src/py/flwr/server/driver/grpc_driver.py | 8 +++-- .../superlink/driver/serverappio_servicer.py | 18 ++++++++++ 7 files changed, 105 insertions(+), 23 deletions(-) diff --git a/src/proto/flwr/proto/serverappio.proto b/src/proto/flwr/proto/serverappio.proto index 3d8d3d6aa0d6..76352866a891 100644 --- a/src/proto/flwr/proto/serverappio.proto +++ b/src/proto/flwr/proto/serverappio.proto @@ -55,6 +55,9 @@ service ServerAppIo { rpc UpdateRunStatus(UpdateRunStatusRequest) returns (UpdateRunStatusResponse) {} + // Get the status of a given run + rpc GetRunStatus(GetRunStatusRequest) returns (GetRunStatusResponse) {} + // Push ServerApp logs rpc PushLogs(PushLogsRequest) returns (PushLogsResponse) {} } @@ -64,13 +67,17 @@ message GetNodesRequest { uint64 run_id = 1; } message GetNodesResponse { repeated Node nodes = 1; } // PushTaskIns messages -message PushTaskInsRequest { repeated TaskIns task_ins_list = 1; } +message PushTaskInsRequest { + repeated TaskIns task_ins_list = 1; + uint64 run_id = 2; +} message PushTaskInsResponse { repeated string task_ids = 2; } // PullTaskRes messages message PullTaskResRequest { Node node = 1; repeated string task_ids = 2; + uint64 run_id = 3; } message PullTaskResResponse { repeated TaskRes task_res_list = 1; } diff --git a/src/py/flwr/proto/serverappio_pb2.py b/src/py/flwr/proto/serverappio_pb2.py index 2bbd33b5c42b..f97d8362e8df 100644 --- a/src/py/flwr/proto/serverappio_pb2.py +++ b/src/py/flwr/proto/serverappio_pb2.py @@ -20,7 +20,7 @@ from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\xca\x06\n\x0bServerAppIo\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"P\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"V\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\x9f\x07\n\x0bServerAppIo\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -32,21 +32,21 @@ _globals['_GETNODESRESPONSE']._serialized_start=217 _globals['_GETNODESRESPONSE']._serialized_end=268 _globals['_PUSHTASKINSREQUEST']._serialized_start=270 - _globals['_PUSHTASKINSREQUEST']._serialized_end=334 - _globals['_PUSHTASKINSRESPONSE']._serialized_start=336 - _globals['_PUSHTASKINSRESPONSE']._serialized_end=375 - _globals['_PULLTASKRESREQUEST']._serialized_start=377 - _globals['_PULLTASKRESREQUEST']._serialized_end=447 - _globals['_PULLTASKRESRESPONSE']._serialized_start=449 - _globals['_PULLTASKRESRESPONSE']._serialized_end=514 - _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=516 - _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=544 - _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=546 - _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=673 - _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=675 - _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=758 - _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=760 - _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=790 - _globals['_SERVERAPPIO']._serialized_start=793 - _globals['_SERVERAPPIO']._serialized_end=1635 + _globals['_PUSHTASKINSREQUEST']._serialized_end=350 + _globals['_PUSHTASKINSRESPONSE']._serialized_start=352 + _globals['_PUSHTASKINSRESPONSE']._serialized_end=391 + _globals['_PULLTASKRESREQUEST']._serialized_start=393 + _globals['_PULLTASKRESREQUEST']._serialized_end=479 + _globals['_PULLTASKRESRESPONSE']._serialized_start=481 + _globals['_PULLTASKRESRESPONSE']._serialized_end=546 + _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=548 + _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=576 + _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=578 + _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=705 + _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=707 + _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=790 + _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=792 + _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=822 + _globals['_SERVERAPPIO']._serialized_start=825 + _globals['_SERVERAPPIO']._serialized_end=1752 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/serverappio_pb2.pyi b/src/py/flwr/proto/serverappio_pb2.pyi index 8191ec663442..38eff6456c03 100644 --- a/src/py/flwr/proto/serverappio_pb2.pyi +++ b/src/py/flwr/proto/serverappio_pb2.pyi @@ -44,13 +44,16 @@ class PushTaskInsRequest(google.protobuf.message.Message): """PushTaskIns messages""" DESCRIPTOR: google.protobuf.descriptor.Descriptor TASK_INS_LIST_FIELD_NUMBER: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int @property def task_ins_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.task_pb2.TaskIns]: ... + run_id: builtins.int def __init__(self, *, task_ins_list: typing.Optional[typing.Iterable[flwr.proto.task_pb2.TaskIns]] = ..., + run_id: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["task_ins_list",b"task_ins_list"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id","task_ins_list",b"task_ins_list"]) -> None: ... global___PushTaskInsRequest = PushTaskInsRequest class PushTaskInsResponse(google.protobuf.message.Message): @@ -70,17 +73,20 @@ class PullTaskResRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor NODE_FIELD_NUMBER: builtins.int TASK_IDS_FIELD_NUMBER: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int @property def node(self) -> flwr.proto.node_pb2.Node: ... @property def task_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... + run_id: builtins.int def __init__(self, *, node: typing.Optional[flwr.proto.node_pb2.Node] = ..., task_ids: typing.Optional[typing.Iterable[typing.Text]] = ..., + run_id: builtins.int = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["node",b"node","task_ids",b"task_ids"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["node",b"node","run_id",b"run_id","task_ids",b"task_ids"]) -> None: ... global___PullTaskResRequest = PullTaskResRequest class PullTaskResResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/serverappio_pb2_grpc.py b/src/py/flwr/proto/serverappio_pb2_grpc.py index 1a7740db4271..ede888543883 100644 --- a/src/py/flwr/proto/serverappio_pb2_grpc.py +++ b/src/py/flwr/proto/serverappio_pb2_grpc.py @@ -62,6 +62,11 @@ def __init__(self, channel): request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString, response_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString, ) + self.GetRunStatus = channel.unary_unary( + '/flwr.proto.ServerAppIo/GetRunStatus', + request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString, + ) self.PushLogs = channel.unary_unary( '/flwr.proto.ServerAppIo/PushLogs', request_serializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.SerializeToString, @@ -135,6 +140,13 @@ def UpdateRunStatus(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def GetRunStatus(self, request, context): + """Get the status of a given run + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def PushLogs(self, request, context): """Push ServerApp logs """ @@ -190,6 +202,11 @@ def add_ServerAppIoServicer_to_server(servicer, server): request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString, response_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.SerializeToString, ), + 'GetRunStatus': grpc.unary_unary_rpc_method_handler( + servicer.GetRunStatus, + request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString, + response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString, + ), 'PushLogs': grpc.unary_unary_rpc_method_handler( servicer.PushLogs, request_deserializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.FromString, @@ -358,6 +375,23 @@ def UpdateRunStatus(request, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + @staticmethod + def GetRunStatus(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.ServerAppIo/GetRunStatus', + flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString, + flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + @staticmethod def PushLogs(request, target, diff --git a/src/py/flwr/proto/serverappio_pb2_grpc.pyi b/src/py/flwr/proto/serverappio_pb2_grpc.pyi index aa2d29473ae8..f4e3fdc208a8 100644 --- a/src/py/flwr/proto/serverappio_pb2_grpc.pyi +++ b/src/py/flwr/proto/serverappio_pb2_grpc.pyi @@ -56,6 +56,11 @@ class ServerAppIoStub: flwr.proto.run_pb2.UpdateRunStatusResponse] """Update the status of a given run""" + GetRunStatus: grpc.UnaryUnaryMultiCallable[ + flwr.proto.run_pb2.GetRunStatusRequest, + flwr.proto.run_pb2.GetRunStatusResponse] + """Get the status of a given run""" + PushLogs: grpc.UnaryUnaryMultiCallable[ flwr.proto.log_pb2.PushLogsRequest, flwr.proto.log_pb2.PushLogsResponse] @@ -135,6 +140,14 @@ class ServerAppIoServicer(metaclass=abc.ABCMeta): """Update the status of a given run""" pass + @abc.abstractmethod + def GetRunStatus(self, + request: flwr.proto.run_pb2.GetRunStatusRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.run_pb2.GetRunStatusResponse: + """Get the status of a given run""" + pass + @abc.abstractmethod def PushLogs(self, request: flwr.proto.log_pb2.PushLogsRequest, diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 05b7ce4be8bc..09318c32b704 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -203,7 +203,9 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: task_ins_list.append(taskins) # Call GrpcDriverStub method res: PushTaskInsResponse = self._stub.PushTaskIns( - PushTaskInsRequest(task_ins_list=task_ins_list) + PushTaskInsRequest( + task_ins_list=task_ins_list, run_id=cast(Run, self._run).run_id + ) ) return list(res.task_ids) @@ -215,7 +217,9 @@ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: """ # Pull TaskRes res: PullTaskResResponse = self._stub.PullTaskRes( - PullTaskResRequest(node=self.node, task_ids=message_ids) + PullTaskResRequest( + node=self.node, task_ids=message_ids, run_id=cast(Run, self._run).run_id + ) ) # Convert TaskRes to Message msgs = [message_from_taskres(taskres) for taskres in res.task_res_list] diff --git a/src/py/flwr/server/superlink/driver/serverappio_servicer.py b/src/py/flwr/server/superlink/driver/serverappio_servicer.py index dddac1a93b1a..ca0b2ec0d8a5 100644 --- a/src/py/flwr/server/superlink/driver/serverappio_servicer.py +++ b/src/py/flwr/server/superlink/driver/serverappio_servicer.py @@ -32,6 +32,7 @@ fab_from_proto, fab_to_proto, run_status_from_proto, + run_status_to_proto, run_to_proto, user_config_from_proto, ) @@ -48,6 +49,8 @@ CreateRunResponse, GetRunRequest, GetRunResponse, + GetRunStatusRequest, + GetRunStatusResponse, UpdateRunStatusRequest, UpdateRunStatusResponse, ) @@ -284,6 +287,21 @@ def PushLogs( state.add_serverapp_log(request.run_id, merged_logs) return PushLogsResponse() + def GetRunStatus( + self, request: GetRunStatusRequest, context: grpc.ServicerContext + ) -> GetRunStatusResponse: + """Get the status of a run.""" + log(DEBUG, "ServerAppIoServicer.GetRunStatus") + state = self.state_factory.state() + + # Get run status from LinkState + run_statuses = state.get_run_status(set(request.run_ids)) + run_status_dict = { + run_id: run_status_to_proto(run_status) + for run_id, run_status in run_statuses.items() + } + return GetRunStatusResponse(run_status_dict=run_status_dict) + def _raise_if(validation_error: bool, detail: str) -> None: if validation_error: