Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(framework:skip) Remove Run from flwr-clientapp #4426

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/proto/flwr/proto/clientappio.proto
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ syntax = "proto3";
package flwr.proto;

import "flwr/proto/fab.proto";
import "flwr/proto/run.proto";
import "flwr/proto/message.proto";

service ClientAppIo {
// Get token
rpc GetToken(GetTokenRequest) returns (GetTokenResponse) {}

// Get Message, Context, and Run
// Get Message, Context, and Fab
rpc PullClientAppInputs(PullClientAppInputsRequest)
returns (PullClientAppInputsResponse) {}

Expand All @@ -51,8 +50,7 @@ message PullClientAppInputsRequest { uint64 token = 1; }
message PullClientAppInputsResponse {
Message message = 1;
Context context = 2;
Run run = 3;
Fab fab = 4;
Fab fab = 3;
}

message PushClientAppOutputsRequest {
Expand Down
1 change: 0 additions & 1 deletion src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,6 @@ def _on_backoff(retry_state: RetryState) -> None:
clientapp_input=ClientAppInputs(
message=message,
context=context,
run=run,
fab=fab,
token=token,
),
Expand Down
16 changes: 8 additions & 8 deletions src/py/flwr/client/clientapp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import grpc

from flwr.cli.config_utils import get_fab_metadata
from flwr.cli.install import install_from_fab
from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.common import Context, Message
Expand All @@ -34,9 +35,8 @@
fab_from_proto,
message_from_proto,
message_to_proto,
run_from_proto,
)
from flwr.common.typing import Fab, Run
from flwr.common.typing import Fab

# pylint: disable=E0611
from flwr.proto.clientappio_pb2 import (
Expand Down Expand Up @@ -115,13 +115,14 @@ def run_clientapp( # pylint: disable=R0914
token = get_token(stub)
time.sleep(1)

# Pull Message, Context, Run and (optional) FAB from SuperNode
message, context, run, fab = pull_message(stub=stub, token=token)
# Pull Message, Context, and (optional) FAB from SuperNode
message, context, fab = pull_message(stub=stub, token=token)

# Install FAB, if provided
if fab:
log(DEBUG, "Flower ClientApp starts FAB installation.")
install_from_fab(fab.content, flwr_dir=None, skip_prompt=True)
fab_id, fab_version = get_fab_metadata(fab.content)

load_client_app_fn = get_load_client_app_fn(
default_app_ref="",
Expand All @@ -133,7 +134,7 @@ def run_clientapp( # pylint: disable=R0914
try:
# Load ClientApp
client_app: ClientApp = load_client_app_fn(
run.fab_id, run.fab_version, fab.hash_str if fab else ""
fab_id, fab_version, fab.hash_str if fab else ""
)

# Execute ClientApp
Expand Down Expand Up @@ -197,7 +198,7 @@ def get_token(stub: grpc.Channel) -> Optional[int]:

def pull_message(
stub: grpc.Channel, token: int
) -> tuple[Message, Context, Run, Optional[Fab]]:
) -> tuple[Message, Context, Optional[Fab]]:
"""Pull message from SuperNode to ClientApp."""
log(INFO, "Pulling ClientAppInputs for token %s", token)
try:
Expand All @@ -206,9 +207,8 @@ def pull_message(
)
message = message_from_proto(res.message)
context = context_from_proto(res.context)
run = run_from_proto(res.run)
fab = fab_from_proto(res.fab) if res.fab else None
return message, context, run, fab
return message, context, fab
except grpc.RpcError as e:
log(ERROR, "[PullClientAppInputs] gRPC error occurred: %s", str(e))
raise e
Expand Down
7 changes: 2 additions & 5 deletions src/py/flwr/client/clientapp/clientappio_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@
fab_to_proto,
message_from_proto,
message_to_proto,
run_to_proto,
)
from flwr.common.typing import Fab, Run
from flwr.common.typing import Fab

# pylint: disable=E0611
from flwr.proto import clientappio_pb2_grpc
Expand All @@ -52,7 +51,6 @@ class ClientAppInputs:

message: Message
context: Context
run: Run
fab: Optional[Fab]
token: int

Expand Down Expand Up @@ -106,7 +104,7 @@ def GetToken(
def PullClientAppInputs(
self, request: PullClientAppInputsRequest, context: grpc.ServicerContext
) -> PullClientAppInputsResponse:
"""Pull Message, Context, and Run."""
"""Pull Message, Context, and Fab."""
log(DEBUG, "ClientAppIo.PullClientAppInputs")

# Fail if no ClientAppInputs are available
Expand Down Expand Up @@ -137,7 +135,6 @@ def PullClientAppInputs(
return PullClientAppInputsResponse(
message=message_to_proto(clientapp_input.message),
context=context_to_proto(clientapp_input.context),
run=run_to_proto(clientapp_input.run),
fab=fab_to_proto(clientapp_input.fab) if clientapp_input.fab else None,
)

Expand Down
16 changes: 2 additions & 14 deletions src/py/flwr/client/clientapp/clientappio_servicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
PushClientAppOutputsResponse,
)
from flwr.proto.message_pb2 import Context as ProtoContext
from flwr.proto.run_pb2 import Run as ProtoRun
from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes

from .clientappio_servicer import ClientAppInputs, ClientAppIoServicer, ClientAppOutputs
Expand Down Expand Up @@ -71,19 +70,12 @@ def test_set_inputs(self) -> None:
state=self.maker.recordset(2, 2, 1),
run_config={"runconfig1": 6.1},
)
run = typing.Run(
run_id=1,
fab_id="lorem",
fab_version="ipsum",
fab_hash="dolor",
override_config=self.maker.user_config(),
)
fab = typing.Fab(
hash_str="abc123#$%",
content=b"\xf3\xf5\xf8\x98",
)

client_input = ClientAppInputs(message, context, run, fab, 1)
client_input = ClientAppInputs(message, context, fab, 1)
client_output = ClientAppOutputs(message, context)

# Execute and assert
Expand Down Expand Up @@ -157,23 +149,19 @@ def test_pull_clientapp_inputs(self) -> None:
mock_response = PullClientAppInputsResponse(
message=message_to_proto(mock_message),
context=ProtoContext(node_id=123),
run=ProtoRun(run_id=61016, fab_id="mock/mock", fab_version="v1.0.0"),
fab=fab_to_proto(mock_fab),
)
self.mock_stub.PullClientAppInputs.return_value = mock_response

# Execute
message, context, run, fab = pull_message(self.mock_stub, token=456)
message, context, fab = pull_message(self.mock_stub, token=456)

# Assert
self.mock_stub.PullClientAppInputs.assert_called_once()
self.assertEqual(len(message.content.parameters_records), 3)
self.assertEqual(len(message.content.metrics_records), 2)
self.assertEqual(len(message.content.configs_records), 1)
self.assertEqual(context.node_id, 123)
self.assertEqual(run.run_id, 61016)
self.assertEqual(run.fab_id, "mock/mock")
self.assertEqual(run.fab_version, "v1.0.0")
if fab:
self.assertEqual(fab.hash_str, mock_fab.hash_str)
self.assertEqual(fab.content, mock_fab.content)
Expand Down
39 changes: 19 additions & 20 deletions src/py/flwr/proto/clientappio_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 2 additions & 7 deletions src/py/flwr/proto/clientappio_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ isort:skip_file
import builtins
import flwr.proto.fab_pb2
import flwr.proto.message_pb2
import flwr.proto.run_pb2
import google.protobuf.descriptor
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
Expand Down Expand Up @@ -77,25 +76,21 @@ class PullClientAppInputsResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
MESSAGE_FIELD_NUMBER: builtins.int
CONTEXT_FIELD_NUMBER: builtins.int
RUN_FIELD_NUMBER: builtins.int
FAB_FIELD_NUMBER: builtins.int
@property
def message(self) -> flwr.proto.message_pb2.Message: ...
@property
def context(self) -> flwr.proto.message_pb2.Context: ...
@property
def run(self) -> flwr.proto.run_pb2.Run: ...
@property
def fab(self) -> flwr.proto.fab_pb2.Fab: ...
def __init__(self,
*,
message: typing.Optional[flwr.proto.message_pb2.Message] = ...,
context: typing.Optional[flwr.proto.message_pb2.Context] = ...,
run: typing.Optional[flwr.proto.run_pb2.Run] = ...,
fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["context",b"context","fab",b"fab","message",b"message","run",b"run"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["context",b"context","fab",b"fab","message",b"message","run",b"run"]) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["context",b"context","fab",b"fab","message",b"message"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["context",b"context","fab",b"fab","message",b"message"]) -> None: ...
global___PullClientAppInputsResponse = PullClientAppInputsResponse

class PushClientAppOutputsRequest(google.protobuf.message.Message):
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/proto/clientappio_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def GetToken(self, request, context):
raise NotImplementedError('Method not implemented!')

def PullClientAppInputs(self, request, context):
"""Get Message, Context, and Run
"""Get Message, Context, and Fab
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/proto/clientappio_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class ClientAppIoStub:
PullClientAppInputs: grpc.UnaryUnaryMultiCallable[
flwr.proto.clientappio_pb2.PullClientAppInputsRequest,
flwr.proto.clientappio_pb2.PullClientAppInputsResponse]
"""Get Message, Context, and Run"""
"""Get Message, Context, and Fab"""

PushClientAppOutputs: grpc.UnaryUnaryMultiCallable[
flwr.proto.clientappio_pb2.PushClientAppOutputsRequest,
Expand All @@ -38,7 +38,7 @@ class ClientAppIoServicer(metaclass=abc.ABCMeta):
request: flwr.proto.clientappio_pb2.PullClientAppInputsRequest,
context: grpc.ServicerContext,
) -> flwr.proto.clientappio_pb2.PullClientAppInputsResponse:
"""Get Message, Context, and Run"""
"""Get Message, Context, and Fab"""
pass

@abc.abstractmethod
Expand Down