Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng committed Nov 4, 2024
1 parent 674e698 commit 52187d0
Showing 1 changed file with 2 additions and 14 deletions.
16 changes: 2 additions & 14 deletions src/py/flwr/client/clientapp/clientappio_servicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
PushClientAppOutputsResponse,
)
from flwr.proto.message_pb2 import Context as ProtoContext
from flwr.proto.run_pb2 import Run as ProtoRun
from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes

from .clientappio_servicer import ClientAppInputs, ClientAppIoServicer, ClientAppOutputs
Expand Down Expand Up @@ -71,19 +70,12 @@ def test_set_inputs(self) -> None:
state=self.maker.recordset(2, 2, 1),
run_config={"runconfig1": 6.1},
)
run = typing.Run(
run_id=1,
fab_id="lorem",
fab_version="ipsum",
fab_hash="dolor",
override_config=self.maker.user_config(),
)
fab = typing.Fab(
hash_str="abc123#$%",
content=b"\xf3\xf5\xf8\x98",
)

client_input = ClientAppInputs(message, context, run, fab, 1)
client_input = ClientAppInputs(message, context, fab, 1)
client_output = ClientAppOutputs(message, context)

# Execute and assert
Expand Down Expand Up @@ -157,23 +149,19 @@ def test_pull_clientapp_inputs(self) -> None:
mock_response = PullClientAppInputsResponse(
message=message_to_proto(mock_message),
context=ProtoContext(node_id=123),
run=ProtoRun(run_id=61016, fab_id="mock/mock", fab_version="v1.0.0"),
fab=fab_to_proto(mock_fab),
)
self.mock_stub.PullClientAppInputs.return_value = mock_response

# Execute
message, context, run, fab = pull_message(self.mock_stub, token=456)
message, context, fab = pull_message(self.mock_stub, token=456)

# Assert
self.mock_stub.PullClientAppInputs.assert_called_once()
self.assertEqual(len(message.content.parameters_records), 3)
self.assertEqual(len(message.content.metrics_records), 2)
self.assertEqual(len(message.content.configs_records), 1)
self.assertEqual(context.node_id, 123)
self.assertEqual(run.run_id, 61016)
self.assertEqual(run.fab_id, "mock/mock")
self.assertEqual(run.fab_version, "v1.0.0")
if fab:
self.assertEqual(fab.hash_str, mock_fab.hash_str)
self.assertEqual(fab.content, mock_fab.content)
Expand Down

0 comments on commit 52187d0

Please sign in to comment.