Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(framework) Include support for lists in common.Scalar #4379

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/py/flwr/client/mod/centraldp_mods.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ def fixedclipping_mod(
f" the server side."
)

clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM])
clipping_norm = fit_ins.config[KEY_CLIPPING_NORM]
if not isinstance(clipping_norm, float):
raise ValueError(f"{KEY_CLIPPING_NORM} should be a float value.")

server_to_client_params = parameters_to_ndarrays(fit_ins.parameters)

# Call inner app
Expand Down Expand Up @@ -124,9 +127,11 @@ def adaptiveclipping_mod(
f"DifferentialPrivacyClientSideFixedClipping wrapper at"
f" the server side."
)
if not isinstance(fit_ins.config[KEY_CLIPPING_NORM], float):

clipping_norm = fit_ins.config[KEY_CLIPPING_NORM]
if not isinstance(clipping_norm, float):
raise ValueError(f"{KEY_CLIPPING_NORM} should be a float value.")
clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM])

server_to_client_params = parameters_to_ndarrays(fit_ins.parameters)

# Call inner app
Expand Down
49 changes: 32 additions & 17 deletions src/py/flwr/common/recordset_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
Parameters,
Scalar,
Status,
Value,
)

EMPTY_TENSOR_KEY = "_empty"
Expand Down Expand Up @@ -129,15 +130,33 @@ def _check_mapping_from_recordscalartype_to_scalar(
record_data: Mapping[str, Union[ConfigsRecordValues, MetricsRecordValues]]
) -> dict[str, Scalar]:
"""Check mapping `common.*RecordValues` into `common.Scalar` is possible."""
for value in record_data.values():
if not isinstance(value, get_args(Scalar)):

# Filter out any list types
def is_valid(__v: Value) -> None:
"""Check if value is of expected type."""
if not isinstance(__v, get_args(Value)):
raise TypeError(
"There is not a 1:1 mapping between `common.Scalar` types and those "
"supported in `common.ConfigsRecordValues` or "
"`common.ConfigsRecordValues`. Consider casting your values to a type "
"supported by the `common.RecordSet` infrastructure. "
f"You used type: {type(value)}"
"Not all values are of valid type."
f" Expected `{Value}` but `{type(__v)}` was passed."
)

for value in record_data.values():

if isinstance(value, list):

if len(value) > 0:
is_valid(value[0])
# all elements in the list must be of the same valid type
# this is needed for protobuf
value_type = type(value[0])
if not all(isinstance(v, value_type) for v in value):
raise TypeError(
"All values in a list must be of the same valid type. "
f"One of {Value}."
)
else:
is_valid(value)

return cast(dict[str, Scalar], record_data)


Expand Down Expand Up @@ -171,9 +190,7 @@ def _fit_or_evaluate_ins_to_recordset(
parametersrecord = parameters_to_parametersrecord(ins.parameters, keep_input)
recordset.parameters_records[f"{ins_str}.parameters"] = parametersrecord

recordset.configs_records[f"{ins_str}.config"] = ConfigsRecord(
ins.config # type: ignore
)
recordset.configs_records[f"{ins_str}.config"] = ConfigsRecord(ins.config)

return recordset

Expand Down Expand Up @@ -239,9 +256,7 @@ def fitres_to_recordset(fitres: FitRes, keep_input: bool) -> RecordSet:

res_str = "fitres"

recordset.configs_records[f"{res_str}.metrics"] = ConfigsRecord(
fitres.metrics # type: ignore
)
recordset.configs_records[f"{res_str}.metrics"] = ConfigsRecord(fitres.metrics)
recordset.metrics_records[f"{res_str}.num_examples"] = MetricsRecord(
{"num_examples": fitres.num_examples},
)
Expand Down Expand Up @@ -311,7 +326,7 @@ def evaluateres_to_recordset(evaluateres: EvaluateRes) -> RecordSet:

# metrics
recordset.configs_records[f"{res_str}.metrics"] = ConfigsRecord(
evaluateres.metrics, # type: ignore
evaluateres.metrics,
)

# status
Expand All @@ -336,7 +351,7 @@ def getparametersins_to_recordset(getparameters_ins: GetParametersIns) -> Record
recordset = RecordSet()

recordset.configs_records["getparametersins.config"] = ConfigsRecord(
getparameters_ins.config, # type: ignore
getparameters_ins.config,
)
return recordset

Expand Down Expand Up @@ -386,7 +401,7 @@ def getpropertiesins_to_recordset(getpropertiesins: GetPropertiesIns) -> RecordS
"""Construct a RecordSet from a GetPropertiesRes object."""
recordset = RecordSet()
recordset.configs_records["getpropertiesins.config"] = ConfigsRecord(
getpropertiesins.config, # type: ignore
getpropertiesins.config,
)
return recordset

Expand All @@ -408,7 +423,7 @@ def getpropertiesres_to_recordset(getpropertiesres: GetPropertiesRes) -> RecordS
recordset = RecordSet()
res_str = "getpropertiesres"
recordset.configs_records[f"{res_str}.properties"] = ConfigsRecord(
getpropertiesres.properties, # type: ignore
getpropertiesres.properties,
)
# status
recordset = _embed_status_into_recordset(
Expand Down
11 changes: 8 additions & 3 deletions src/py/flwr/common/recordset_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ def get_ndarrays() -> NDArrays:

def _get_valid_fitins() -> FitIns:
arrays = get_ndarrays()
return FitIns(parameters=ndarrays_to_parameters(arrays), config={"a": 1.0, "b": 0})
return FitIns(
parameters=ndarrays_to_parameters(arrays),
config={"a": 1.0, "b": 0, "c": [1.0, 2.0, 3.0]},
)


def _get_valid_fitins_with_empty_ndarrays() -> FitIns:
Expand All @@ -82,7 +85,7 @@ def _get_valid_fitins_with_empty_ndarrays() -> FitIns:
def _get_valid_fitres() -> FitRes:
"""Returnn Valid parameters but potentially invalid config."""
arrays = get_ndarrays()
metrics: dict[str, Scalar] = {"a": 1.0, "b": 0}
metrics: dict[str, Scalar] = {"a": 1.0, "b": 0, "cc": [1.0, 2.0, 3.0]}
return FitRes(
parameters=ndarrays_to_parameters(arrays),
num_examples=1,
Expand All @@ -98,7 +101,7 @@ def _get_valid_evaluateins() -> EvaluateIns:

def _get_valid_evaluateres() -> EvaluateRes:
"""Return potentially invalid config."""
metrics: dict[str, Scalar] = {"a": 1.0, "b": 0}
metrics: dict[str, Scalar] = {"a": 1.0, "b": 0, "c": [1.0, 2.0, 3.0]}
return EvaluateRes(
num_examples=1,
loss=0.1,
Expand All @@ -112,6 +115,7 @@ def _get_valid_getparametersins() -> GetParametersIns:
"a": 1.0,
"b": 3,
"c": True,
"d": [True, False, True],
} # valid since both Ins/Res communicate over ConfigsRecord

return GetParametersIns(config_dict)
Expand All @@ -135,6 +139,7 @@ def _get_valid_getpropertiesres() -> GetPropertiesRes:
"a": 1.0,
"b": 3,
"c": True,
"d": [1, 2, 3],
} # valid since both Ins/Res communicate over ConfigsRecord

return GetPropertiesRes(
Expand Down
10 changes: 3 additions & 7 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,9 @@
# ProtoBuf considers to be "Scalar Value Types", even though some of them arguably do
# not conform to other definitions of what a scalar is. Source:
# https://developers.google.com/protocol-buffers/docs/overview#scalar
Scalar = Union[bool, bytes, float, int, str]
Value = Union[
bool,
bytes,
float,
int,
str,
Value = Union[bool, bytes, float, int, str]
Scalar = Union[
Value,
list[bool],
list[bytes],
list[float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def __init__(self, state: RecordSet) -> None:

def get_properties(self, config: Config) -> dict[str, Scalar]:
"""Return properties by doing a simple calculation."""
result = float(config["factor"]) * pi
factor = config["factor"]
assert isinstance(factor, float)
result = factor * pi

# store something in context
self.client_state.configs_records["result"] = ConfigsRecord({"result": result})
Expand Down Expand Up @@ -86,7 +88,7 @@ def backend_build_process_and_termination(
def _create_message_and_context() -> tuple[Message, Context, float]:

# Construct a Message
mult_factor = 2024
mult_factor = 2024.0
run_id = 0
getproperties_ins = GetPropertiesIns(config={"factor": mult_factor})
recordset = getpropertiesins_to_recordset(getproperties_ins)
Expand Down
6 changes: 4 additions & 2 deletions src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def __init__(self, state: RecordSet) -> None:

def get_properties(self, config: Config) -> dict[str, Scalar]:
"""Return properties by doing a simple calculation."""
result = float(config["factor"]) * pi
factor = config["factor"]
assert isinstance(factor, float)
result = factor * pi

# store something in context
self.client_state.configs_records["result"] = ConfigsRecord({"result": result})
Expand Down Expand Up @@ -137,7 +139,7 @@ def register_messages_into_state(
for i in range(num_messages):
dst_node_id = next(nodes_cycle)
# Construct a Message
mult_factor = 2024 + i
mult_factor = 2024.0 + i
getproperties_ins = GetPropertiesIns(config={"factor": mult_factor})
recordset = getpropertiesins_to_recordset(getproperties_ins)
message = Message(
Expand Down