Skip to content

Commit

Permalink
refactor(framework) Introduce run status checks in SimulationIo (#4729)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Dec 17, 2024
1 parent 470ba4a commit e9c3653
Show file tree
Hide file tree
Showing 2 changed files with 222 additions and 0 deletions.
209 changes: 209 additions & 0 deletions src/py/flwr/server/superlink/simulation/simulation_servicer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SimulationIoServicer tests."""


import tempfile
import unittest

import grpc
from parameterized import parameterized

from flwr.common import ConfigsRecord, Context
from flwr.common.constant import SIMULATIONIO_API_DEFAULT_SERVER_ADDRESS, Status
from flwr.common.serde import context_to_proto, run_status_to_proto
from flwr.common.serde_test import RecordMaker
from flwr.common.typing import RunStatus
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
UpdateRunStatusRequest,
UpdateRunStatusResponse,
)
from flwr.proto.simulationio_pb2 import ( # pylint: disable=E0611
PushSimulationOutputsRequest,
PushSimulationOutputsResponse,
)
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.linkstate.linkstate_factory import LinkStateFactory
from flwr.server.superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
from flwr.server.superlink.utils import _STATUS_TO_MSG


class TestSimulationIoServicer(unittest.TestCase):  # pylint: disable=R0902
    """SimulationIoServicer tests for allowed RunStatuses.

    Every test starts a SimulationIo gRPC server through
    `run_simulationio_api_grpc`, talks to it over an insecure channel and
    checks whether an RPC is accepted or rejected for the current status of
    the run.
    """

    def setUp(self) -> None:
        """Start the SimulationIo gRPC server and create the client stubs."""
        # Create a temporary directory
        self.temp_dir = tempfile.TemporaryDirectory()  # pylint: disable=R1732
        self.addCleanup(self.temp_dir.cleanup)  # Ensures cleanup after test

        state_factory = LinkStateFactory(":flwr-in-memory-state:")
        self.state = state_factory.state()
        ffs_factory = FfsFactory(self.temp_dir.name)
        self.ffs = ffs_factory.ffs()

        # Expected error details, keyed by run status
        self.status_to_msg = _STATUS_TO_MSG

        self._server: grpc.Server = run_simulationio_api_grpc(
            SIMULATIONIO_API_DEFAULT_SERVER_ADDRESS,
            state_factory,
            ffs_factory,
            None,
        )

        # NOTE(review): the port is hard-coded here and presumably has to match
        # the one in SIMULATIONIO_API_DEFAULT_SERVER_ADDRESS -- confirm.
        self._channel = grpc.insecure_channel("localhost:9096")
        self._push_simulation_outputs = self._channel.unary_unary(
            "/flwr.proto.SimulationIo/PushSimulationOutputs",
            request_serializer=PushSimulationOutputsRequest.SerializeToString,
            response_deserializer=PushSimulationOutputsResponse.FromString,
        )
        self._update_run_status = self._channel.unary_unary(
            "/flwr.proto.SimulationIo/UpdateRunStatus",
            request_serializer=UpdateRunStatusRequest.SerializeToString,
            response_deserializer=UpdateRunStatusResponse.FromString,
        )

    def tearDown(self) -> None:
        """Clean up grpc channel and server."""
        # Close the client channel as well; otherwise one channel is leaked per
        # test (and per parameterized case).
        self._channel.close()
        self._server.stop(None)

    @staticmethod
    def _make_context(run_id: int) -> Context:
        """Build a `Context` for `run_id` populated via `RecordMaker`."""
        maker = RecordMaker()
        return Context(
            run_id=run_id,
            node_id=0,
            node_config=maker.user_config(),
            state=maker.recordset(1, 1, 1),
            run_config=maker.user_config(),
        )

    def _transition_run_status(self, run_id: int, num_transitions: int) -> None:
        """Request up to three status updates for `run_id` on the state.

        The updates are requested in the order STARTING, RUNNING, FINISHED;
        `num_transitions` (0-3) selects how many of them are issued. The
        result of each `update_run_status` call is ignored.
        """
        if num_transitions > 0:
            _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", ""))
        if num_transitions > 1:
            _ = self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", ""))
        if num_transitions > 2:
            _ = self.state.update_run_status(run_id, RunStatus(Status.FINISHED, "", ""))

    def test_push_simulation_outputs_successful_if_running(self) -> None:
        """Test `PushSimulationOutputs` success."""
        # Prepare
        run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
        context = self._make_context(run_id)

        # Transition status to running. PushSimulationOutputs is only allowed
        # in running status.
        self._transition_run_status(run_id, 2)
        request = PushSimulationOutputsRequest(
            run_id=run_id, context=context_to_proto(context)
        )

        # Execute
        response, call = self._push_simulation_outputs.with_call(request=request)

        # Assert
        assert isinstance(response, PushSimulationOutputsResponse)
        assert grpc.StatusCode.OK == call.code()

    def _assert_push_simulation_outputs_not_allowed(
        self, run_id: int, context: Context
    ) -> None:
        """Assert `PushSimulationOutputs` not allowed.

        Sends the request for `run_id` and expects a `grpc.RpcError` with code
        PERMISSION_DENIED whose details equal the message registered in
        `self.status_to_msg` for the run's current status.
        """
        run_status = self.state.get_run_status({run_id})[run_id]
        request = PushSimulationOutputsRequest(
            run_id=run_id, context=context_to_proto(context)
        )

        with self.assertRaises(grpc.RpcError) as e:
            self._push_simulation_outputs.with_call(request=request)
        # Checked after the `with` block: statements placed after the raising
        # call inside the block would never run.
        assert e.exception.code() == grpc.StatusCode.PERMISSION_DENIED
        assert e.exception.details() == self.status_to_msg[run_status.status]

    @parameterized.expand(
        [
            (0,),  # Test not successful if RunStatus is pending.
            (1,),  # Test not successful if RunStatus is starting.
            (3,),  # Test not successful if RunStatus is finished.
        ]
    )  # type: ignore
    def test_push_simulation_outputs_not_successful_if_not_running(
        self, num_transitions: int
    ) -> None:
        """Test `PushSimulationOutputs` not successful if RunStatus is not running."""
        # Prepare
        run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
        context = self._make_context(run_id)

        self._transition_run_status(run_id, num_transitions)

        # Execute & Assert
        self._assert_push_simulation_outputs_not_allowed(run_id, context)

    @parameterized.expand(
        [
            (0,),  # Test successful if RunStatus is pending.
            (1,),  # Test successful if RunStatus is starting.
            (2,),  # Test successful if RunStatus is running.
        ]
    )  # type: ignore
    def test_update_run_status_successful_if_not_finished(
        self, num_transitions: int
    ) -> None:
        """Test `UpdateRunStatus` success."""
        # Prepare
        run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
        _ = self.state.get_run_status({run_id})[run_id]

        # Apply the first `num_transitions` updates, then request the status
        # that comes next in the STARTING -> RUNNING -> FINISHED sequence.
        self._transition_run_status(run_id, num_transitions)
        status_sequence = (Status.STARTING, Status.RUNNING, Status.FINISHED)
        next_run_status = RunStatus(status_sequence[num_transitions], "", "")

        request = UpdateRunStatusRequest(
            run_id=run_id, run_status=run_status_to_proto(next_run_status)
        )

        # Execute
        response, call = self._update_run_status.with_call(request=request)

        # Assert
        assert isinstance(response, UpdateRunStatusResponse)
        assert grpc.StatusCode.OK == call.code()

    def test_update_run_status_not_successful_if_finished(self) -> None:
        """Test `UpdateRunStatus` not successful."""
        # Prepare
        run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
        _ = self.state.get_run_status({run_id})[run_id]
        _ = self.state.update_run_status(run_id, RunStatus(Status.FINISHED, "", ""))
        run_status = self.state.get_run_status({run_id})[run_id]
        next_run_status = RunStatus(Status.FINISHED, "", "")

        request = UpdateRunStatusRequest(
            run_id=run_id, run_status=run_status_to_proto(next_run_status)
        )

        # Execute & Assert
        with self.assertRaises(grpc.RpcError) as e:
            self._update_run_status.with_call(request=request)
        # Checked after the `with` block so the assertions actually execute.
        assert e.exception.code() == grpc.StatusCode.PERMISSION_DENIED
        assert e.exception.details() == self.status_to_msg[run_status.status]
13 changes: 13 additions & 0 deletions src/py/flwr/server/superlink/simulation/simulationio_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
)
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.linkstate import LinkStateFactory
from flwr.server.superlink.utils import abort_if


class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
Expand Down Expand Up @@ -110,6 +111,15 @@ def PushSimulationOutputs(
"""Push Simulation process outputs."""
log(DEBUG, "SimultionIoServicer.PushSimulationOutputs")
state = self.state_factory.state()

# Abort if the run is not running
abort_if(
request.run_id,
[Status.PENDING, Status.STARTING, Status.FINISHED],
state,
context,
)

state.set_serverapp_context(request.run_id, context_from_proto(request.context))
return PushSimulationOutputsResponse()

Expand All @@ -120,6 +130,9 @@ def UpdateRunStatus(
log(DEBUG, "SimultionIoServicer.UpdateRunStatus")
state = self.state_factory.state()

# Abort if the run is finished
abort_if(request.run_id, [Status.FINISHED], state, context)

# Update the run status
state.update_run_status(
run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
Expand Down

0 comments on commit e9c3653

Please sign in to comment.