Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add *Record ProtoBuf messages and corresponding serde functions. #2831

Merged
merged 16 commits into from
Jan 22, 2024
70 changes: 70 additions & 0 deletions src/proto/flwr/proto/recordset.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright 2024 Flower Labs GmbH. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================

syntax = "proto3";

package flwr.proto;

// Wrapper messages holding a homogeneous list of one scalar type each.
// They are used as the list-typed members of the `value` oneofs in
// `MetricsRecordValue` and `ConfigsRecordValue`.
message DoubleList { repeated double vals = 1; }
message Sint64List { repeated sint64 vals = 1; }
message BoolList { repeated bool vals = 1; }
message StringList { repeated string vals = 1; }
message BytesList { repeated bytes vals = 1; }

// Serialized array: metadata (`dtype`, `shape`, `stype`) plus the raw payload.
// The Python serde copies these four fields one-to-one to/from the Python
// `Array` class.
message Array {
// NOTE(review): presumably the name of the element data type - confirm.
string dtype = 1;
// NOTE(review): presumably the size of each dimension - confirm.
repeated int32 shape = 2;
// NOTE(review): presumably identifies how `data` was serialized - confirm.
string stype = 3;
// Raw bytes of the array contents.
bytes data = 4;
}

// One value of a `MetricsRecord`: exactly one member of `value` is set,
// either a single double/sint64 or a list of one of those types.
message MetricsRecordValue {
oneof value {
// Single element
double double = 1;
sint64 sint64 = 2;

// List types
// (field numbers start at 21, leaving 3-20 unused after the single elements)
DoubleList double_list = 21;
Sint64List sint64_list = 22;
}
}

// One value of a `ConfigsRecord`: exactly one member of `value` is set,
// either a single double/sint64/bool/string/bytes or a list of one of those
// types.
message ConfigsRecordValue {
oneof value {
// Single element
double double = 1;
sint64 sint64 = 2;
bool bool = 3;
string string = 4;
bytes bytes = 5;

// List types
// (each list field number is the matching single-element number plus 20)
DoubleList double_list = 21;
Sint64List sint64_list = 22;
BoolList bool_list = 23;
StringList string_list = 24;
BytesList bytes_list = 25;
}
}

// Named arrays stored as two parallel lists: `data_keys[i]` is the name of
// `data_values[i]`. The Python serde pairs the two lists positionally and
// rebuilds an ordered mapping from them, so both lists are expected to have
// the same length.
message ParametersRecord {
  repeated string data_keys = 1;
  repeated Array data_values = 2;
}

// Mapping from metric name to its value (see `MetricsRecordValue`).
message MetricsRecord { map<string, MetricsRecordValue> data = 1; }

// Mapping from config name to its value (see `ConfigsRecordValue`).
message ConfigsRecord { map<string, ConfigsRecordValue> data = 1; }
7 changes: 1 addition & 6 deletions src/proto/flwr/proto/task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ syntax = "proto3";
package flwr.proto;

import "flwr/proto/node.proto";
import "flwr/proto/recordset.proto";
import "flwr/proto/transport.proto";

message Task {
Expand Down Expand Up @@ -49,12 +50,6 @@ message TaskRes {
}

message Value {
message DoubleList { repeated double vals = 1; }
message Sint64List { repeated sint64 vals = 1; }
message BoolList { repeated bool vals = 1; }
message StringList { repeated string vals = 1; }
message BytesList { repeated bytes vals = 1; }

oneof value {
// Single element
double double = 1;
Expand Down
180 changes: 163 additions & 17 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,22 @@
"""ProtoBuf serialization and deserialization."""


from typing import Any, Dict, List, MutableMapping, cast

from flwr.proto.task_pb2 import Value # pylint: disable=E0611
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast

from google.protobuf.message import Message

# pylint: disable=E0611
from flwr.proto.recordset_pb2 import Array as ProtoArray
from flwr.proto.recordset_pb2 import BoolList, BytesList
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue
from flwr.proto.recordset_pb2 import DoubleList
from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
from flwr.proto.recordset_pb2 import Sint64List, StringList
from flwr.proto.task_pb2 import Value
from flwr.proto.transport_pb2 import (
ClientMessage,
Code,
Parameters,
Expand All @@ -28,7 +40,11 @@
Status,
)

# pylint: enable=E0611
from . import typing
from .configsrecord import ConfigsRecord
from .metricsrecord import MetricsRecord
from .parametersrecord import Array, ParametersRecord

# === ServerMessage message ===

Expand Down Expand Up @@ -493,7 +509,7 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
# === Value messages ===


_python_type_to_field_name = {
_type_to_field = {
float: "double",
int: "sint64",
bool: "bool",
Expand All @@ -502,22 +518,20 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
}


_python_list_type_to_message_and_field_name = {
float: (Value.DoubleList, "double_list"),
int: (Value.Sint64List, "sint64_list"),
bool: (Value.BoolList, "bool_list"),
str: (Value.StringList, "string_list"),
bytes: (Value.BytesList, "bytes_list"),
_list_type_to_class_and_field = {
float: (DoubleList, "double_list"),
int: (Sint64List, "sint64_list"),
bool: (BoolList, "bool_list"),
str: (StringList, "string_list"),
bytes: (BytesList, "bytes_list"),
}


def _check_value(value: typing.Value) -> None:
if isinstance(value, tuple(_python_type_to_field_name.keys())):
if isinstance(value, tuple(_type_to_field.keys())):
return
if isinstance(value, list):
if len(value) > 0 and isinstance(
value[0], tuple(_python_type_to_field_name.keys())
):
if len(value) > 0 and isinstance(value[0], tuple(_type_to_field.keys())):
data_type = type(value[0])
for element in value:
if isinstance(element, data_type):
Expand All @@ -539,12 +553,12 @@ def value_to_proto(value: typing.Value) -> Value:

arg = {}
if isinstance(value, list):
msg_class, field_name = _python_list_type_to_message_and_field_name[
msg_class, field_name = _list_type_to_class_and_field[
type(value[0]) if len(value) > 0 else int
]
arg[field_name] = msg_class(vals=value)
else:
arg[_python_type_to_field_name[type(value)]] = value
arg[_type_to_field[type(value)]] = value
return Value(**arg)


Expand Down Expand Up @@ -573,3 +587,135 @@ def named_values_from_proto(
) -> Dict[str, typing.Value]:
"""Deserialize named values from ProtoBuf."""
return {name: value_from_proto(value) for name, value in named_values_proto.items()}


# === Record messages ===


T = TypeVar("T")


def _record_value_to_proto(
value: Any, allowed_types: List[type], proto_class: Type[T]
) -> T:
"""Serialize `*RecordValue` to ProtoBuf."""
arg = {}
for t in allowed_types:
# Single element
# Note: `isinstance(False, int) == True`.
if type(value) == t: # pylint: disable=C0123
arg[_type_to_field[t]] = value
return proto_class(**arg)
# List
if isinstance(value, list) and all(isinstance(item, t) for item in value):
list_class, field_name = _list_type_to_class_and_field[t]
arg[field_name] = list_class(vals=value)
return proto_class(**arg)
# Invalid types
raise TypeError(
f"The type of the following value is not allowed "
f"in '{proto_class.__name__}':\n{value}"
)


def _record_value_from_proto(value_proto: Message) -> Any:
    """Deserialize `*RecordValue` from ProtoBuf.

    Parameters
    ----------
    value_proto : Message
        A `*RecordValue` message whose `value` oneof holds the payload.

    Returns
    -------
    Any
        A Python `list` copied from `vals` if a list field is set, otherwise
        the value of the single-element field that is set.

    Raises
    ------
    ValueError
        If no member of the `value` oneof is set.
    """
    value_field = value_proto.WhichOneof("value")
    # `WhichOneof` returns None for a message with no member set; fail with a
    # clear error instead of an AttributeError on `None.endswith`.
    if value_field is None:
        raise ValueError(
            f"No field of oneof 'value' is set in '{type(value_proto).__name__}'."
        )
    field_value = getattr(value_proto, value_field)
    # List fields wrap their elements in `vals`; copy them into a plain list.
    if value_field.endswith("list"):
        return list(field_value.vals)
    return field_value


def _record_value_dict_to_proto(
    value_dict: Dict[str, Any], allowed_types: List[type], value_proto_class: Type[T]
) -> Dict[str, T]:
    """Serialize every value of a record dict to ProtoBuf.

    Keys are kept as-is; each value is converted with
    `_record_value_to_proto` using the given allowed types and message class.
    """
    serialized: Dict[str, T] = {}
    for key, val in value_dict.items():
        serialized[key] = _record_value_to_proto(
            val, allowed_types, value_proto_class
        )
    return serialized


def _record_value_dict_from_proto(
value_dict_proto: MutableMapping[str, Any]
) -> Dict[str, Any]:
"""Deserialize the record value dict from ProtoBuf."""
return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()}


def array_to_proto(array: Array) -> ProtoArray:
    """Serialize Array to ProtoBuf.

    Every instance attribute of `array` is forwarded to `ProtoArray` as a
    keyword argument of the same name.
    """
    field_values = vars(array)
    return ProtoArray(**field_values)


def array_from_proto(array_proto: ProtoArray) -> Array:
    """Deserialize Array from ProtoBuf.

    `dtype`, `stype` and `data` are passed through unchanged; the repeated
    `shape` field is copied into a plain Python list.
    """
    dims = list(array_proto.shape)
    fields = {
        "dtype": array_proto.dtype,
        "shape": dims,
        "stype": array_proto.stype,
        "data": array_proto.data,
    }
    return Array(**fields)


def parameters_record_to_proto(record: ParametersRecord) -> ProtoParametersRecord:
    """Serialize ParametersRecord to ProtoBuf.

    The keys and the serialized arrays of `record.data` are written to two
    parallel repeated fields, in the mapping's iteration order.
    """
    arrays = record.data
    keys = list(arrays.keys())
    values = [array_to_proto(arr) for arr in arrays.values()]
    return ProtoParametersRecord(data_keys=keys, data_values=values)


def parameters_record_from_proto(
    record_proto: ProtoParametersRecord,
) -> ParametersRecord:
    """Deserialize ParametersRecord from ProtoBuf.

    `data_keys` and `data_values` are paired positionally. Because `zip` stops
    at the shorter input, surplus entries of the longer list are ignored.
    """
    arrays = (array_from_proto(arr) for arr in record_proto.data_values)
    array_dict = OrderedDict(zip(record_proto.data_keys, arrays))
    return ParametersRecord(array_dict=array_dict, keep_input=False)


def metrics_record_to_proto(record: MetricsRecord) -> ProtoMetricsRecord:
    """Serialize MetricsRecord to ProtoBuf.

    Values may be `float` or `int`, either single or as lists; anything else
    is rejected by the underlying value serializer.
    """
    allowed_types = [float, int]
    value_protos = _record_value_dict_to_proto(
        record.data, allowed_types, ProtoMetricsRecordValue
    )
    return ProtoMetricsRecord(data=value_protos)


def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord:
    """Deserialize MetricsRecord from ProtoBuf.

    The `data` map is converted to a plain dict of Python values and handed
    to `MetricsRecord` with `keep_input=False`.
    """
    metrics = cast(
        Dict[str, typing.MetricsRecordValues],
        _record_value_dict_from_proto(record_proto.data),
    )
    return MetricsRecord(metrics_dict=metrics, keep_input=False)


def configs_record_to_proto(record: ConfigsRecord) -> ProtoConfigsRecord:
    """Serialize ConfigsRecord to ProtoBuf.

    Values may be `int`, `float`, `bool`, `str` or `bytes`, either single or
    as lists; the types are tried in that order by the value serializer.
    """
    allowed_types = [int, float, bool, str, bytes]
    value_protos = _record_value_dict_to_proto(
        record.data, allowed_types, ProtoConfigsRecordValue
    )
    return ProtoConfigsRecord(data=value_protos)


def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord:
    """Deserialize ConfigsRecord from ProtoBuf.

    The `data` map is converted to a plain dict of Python values and handed
    to `ConfigsRecord` with `keep_input=False`.
    """
    configs = cast(
        Dict[str, typing.ConfigsRecordValues],
        _record_value_dict_from_proto(record_proto.data),
    )
    return ConfigsRecord(configs_dict=configs, keep_input=False)
Loading