Skip to content

Commit

Permalink
feat(framework) Add checks to ServerAppIoServicer to allow context …
Browse files Browse the repository at this point in the history
…abort via `flwr stop` (#4646)

Co-authored-by: Heng Pan <[email protected]>
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
3 people authored Dec 12, 2024
1 parent 3e6c0b4 commit 3966e8d
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 6 deletions.
56 changes: 50 additions & 6 deletions src/py/flwr/server/superlink/driver/serverappio_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from flwr.server.superlink.ffs.ffs import Ffs
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
from flwr.server.superlink.utils import abort_if
from flwr.server.utils.validator import validate_task_ins_or_res


Expand All @@ -88,7 +89,18 @@ def GetNodes(
) -> GetNodesResponse:
"""Get available nodes."""
log(DEBUG, "ServerAppIoServicer.GetNodes")

# Init state
state: LinkState = self.state_factory.state()

# Abort if the run is not running
abort_if(
request.run_id,
[Status.PENDING, Status.STARTING, Status.FINISHED],
state,
context,
)

all_ids: set[int] = state.get_nodes(request.run_id)
nodes: list[Node] = [
Node(node_id=node_id, anonymous=False) for node_id in all_ids
Expand Down Expand Up @@ -126,6 +138,17 @@ def PushTaskIns(
"""Push a set of TaskIns."""
log(DEBUG, "ServerAppIoServicer.PushTaskIns")

# Init state
state: LinkState = self.state_factory.state()

# Abort if the run is not running
abort_if(
request.run_id,
[Status.PENDING, Status.STARTING, Status.FINISHED],
state,
context,
)

# Set pushed_at (timestamp in seconds)
pushed_at = time.time()
for task_ins in request.task_ins_list:
Expand All @@ -137,9 +160,6 @@ def PushTaskIns(
validation_errors = validate_task_ins_or_res(task_ins)
_raise_if(bool(validation_errors), ", ".join(validation_errors))

# Init state
state: LinkState = self.state_factory.state()

# Store each TaskIns
task_ids: list[Optional[UUID]] = []
for task_ins in request.task_ins_list:
Expand All @@ -156,12 +176,20 @@ def PullTaskRes(
"""Pull a set of TaskRes."""
log(DEBUG, "ServerAppIoServicer.PullTaskRes")

# Convert each task_id str to UUID
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}

# Init state
state: LinkState = self.state_factory.state()

# Abort if the run is not running
abort_if(
request.run_id,
[Status.PENDING, Status.STARTING, Status.FINISHED],
state,
context,
)

# Convert each task_id str to UUID
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}

# Register callback
def on_rpc_done() -> None:
log(
Expand Down Expand Up @@ -258,7 +286,18 @@ def PushServerAppOutputs(
) -> PushServerAppOutputsResponse:
"""Push ServerApp process outputs."""
log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")

# Init state
state = self.state_factory.state()

# Abort if the run is not running
abort_if(
request.run_id,
[Status.PENDING, Status.STARTING, Status.FINISHED],
state,
context,
)

state.set_serverapp_context(request.run_id, context_from_proto(request.context))
return PushServerAppOutputsResponse()

Expand All @@ -267,8 +306,13 @@ def UpdateRunStatus(
) -> UpdateRunStatusResponse:
"""Update the status of a run."""
log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")

# Init state
state = self.state_factory.state()

# Abort if the run is finished
abort_if(request.run_id, [Status.FINISHED], state, context)

# Update the run status
state.update_run_status(
run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
Expand Down
65 changes: 65 additions & 0 deletions src/py/flwr/server/superlink/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SuperLink utilities."""


from typing import Union

import grpc

from flwr.common.constant import Status, SubStatus
from flwr.common.typing import RunStatus
from flwr.server.superlink.linkstate import LinkState

# Human-readable message for each run status. `check_abort` uses the matching
# entry as the base of the abort message it returns.
_STATUS_TO_MSG = {
    Status.PENDING: "Run is pending.",
    Status.STARTING: "Run is starting.",
    Status.RUNNING: "Run is running.",
    Status.FINISHED: "Run is finished.",
}


def check_abort(
    run_id: int,
    abort_status_list: list[str],
    state: LinkState,
) -> Union[str, None]:
    """Check if the status of the provided `run_id` is in `abort_status_list`.

    Parameters
    ----------
    run_id : int
        Identifier of the run whose status is looked up in `state`.
    abort_status_list : list[str]
        Statuses for which an abort message should be produced.
    state : LinkState
        State object queried via `get_run_status`.

    Returns
    -------
    Union[str, None]
        An abort message if the run's status is in `abort_status_list` (with
        " Stopped by user." appended when the sub-status is
        `SubStatus.STOPPED`), or if no status is available for `run_id`.
        `None` if the run may proceed.
    """
    # NOTE(review): assumes `get_run_status` returns a dict-like mapping that
    # omits unknown run IDs -- confirm against the LinkState implementations.
    # Indexing with `[run_id]` would raise an unhandled KeyError for such IDs,
    # so report the condition as an abort message instead.
    run_status = state.get_run_status({run_id}).get(run_id)
    if run_status is None:
        return f"Run {run_id} not found."

    if run_status.status in abort_status_list:
        msg = _STATUS_TO_MSG[run_status.status]
        # Distinguish a user-initiated stop from other outcomes
        if run_status.sub_status == SubStatus.STOPPED:
            msg += " Stopped by user."
        return msg

    return None


def abort_grpc_context(msg: Union[str, None], context: grpc.ServicerContext) -> None:
    """Abort `context` with status code PERMISSION_DENIED when `msg` is set.

    Nothing happens if `msg` is None; otherwise `msg` is passed to
    `context.abort` as the details string.
    """
    # Nothing to report: leave the context untouched
    if msg is None:
        return

    context.abort(grpc.StatusCode.PERMISSION_DENIED, msg)


def abort_if(
    run_id: int,
    abort_status_list: list[str],
    state: LinkState,
    context: grpc.ServicerContext,
) -> None:
    """Abort `context` if the status of `run_id` is in `abort_status_list`.

    Combines `check_abort` (which produces an optional message) with
    `abort_grpc_context` (which aborts only when a message is present).
    """
    abort_msg = check_abort(
        run_id=run_id,
        abort_status_list=abort_status_list,
        state=state,
    )
    abort_grpc_context(msg=abort_msg, context=context)

0 comments on commit 3966e8d

Please sign in to comment.