Skip to content

Commit

Permalink
Merge branch 'main' into driver-retry
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Jan 22, 2024
2 parents c254177 + 2559259 commit 957a2d0
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
13 changes: 11 additions & 2 deletions src/py/flwr/common/configsrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,17 @@ def is_valid(value: ConfigsScalar) -> None:
# 1s to check 10M element list on a M2 Pro
# In such settings, you'd be better of treating such config as
# an array and pass it to a ParametersRecord.
for list_value in value:
is_valid(list_value)
# Empty lists are valid
if len(value) > 0:
is_valid(value[0])
# all elements in the list must be of the same valid type
# this is needed for protobuf
value_type = type(value[0])
if not all(isinstance(v, value_type) for v in value):
raise TypeError(
"All values in a list must be of the same valid type. "
f"One of {ConfigsScalar}."
)
else:
is_valid(value)

Expand Down
13 changes: 11 additions & 2 deletions src/py/flwr/common/metricsrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,17 @@ def is_valid(value: MetricsScalar) -> None:
# 1s to check 10M element list on a M2 Pro
# In such settings, you'd be better of treating such metric as
# an array and pass it to a ParametersRecord.
for list_value in value:
is_valid(list_value)
# Empty lists are valid
if len(value) > 0:
is_valid(value[0])
# all elements in the list must be of the same valid type
# this is needed for protobuf
value_type = type(value[0])
if not all(isinstance(v, value_type) for v in value):
raise TypeError(
"All values in a list must be of the same valid type. "
f"One of {MetricsScalar}."
)
else:
is_valid(value)

Expand Down
10 changes: 10 additions & 0 deletions src/py/flwr/common/recordset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def test_set_parameters_with_incorrect_types(
(str, lambda x: float(x.flatten()[0])), # str: float
(str, lambda x: x.flatten().astype("int").tolist()), # str: List[int]
(str, lambda x: x.flatten().astype("float").tolist()), # str: List[float]
(str, lambda x: []), # str: empty list
],
)
def test_set_metrics_to_metricsrecord_with_correct_types(
Expand Down Expand Up @@ -203,6 +204,10 @@ def test_set_metrics_to_metricsrecord_with_correct_types(
str,
lambda x: [{str(v): v for v in x.flatten()}],
), # str: List[dict[str: float]] (supported: unsupported)
(
str,
lambda x: [1, 2.0, 3.0, 4],
), # str: List[mixing valid types] (supported: unsupported)
(
int,
lambda x: x.flatten().tolist(),
Expand Down Expand Up @@ -278,6 +283,7 @@ def test_set_metrics_to_metricsrecord_with_and_without_keeping_input(
(str, lambda x: x.flatten().astype("float").tolist()), # str: List[float]
(str, lambda x: x.flatten().astype("bool").tolist()), # str: List[bool]
(str, lambda x: [x.flatten().tobytes()]), # str: List[bytes]
(str, lambda x: []), # str: empyt list
],
)
def test_set_configs_to_configsrecord_with_correct_types(
Expand Down Expand Up @@ -310,6 +316,10 @@ def test_set_configs_to_configsrecord_with_correct_types(
str,
lambda x: [{str(v): v for v in x.flatten()}],
), # str: List[dict[str: float]] (supported: unsupported)
(
str,
lambda x: [1, 2.0, 3.0, 4],
), # str: List[mixing valid types] (supported: unsupported)
(
int,
lambda x: x.flatten().tolist(),
Expand Down

0 comments on commit 957a2d0

Please sign in to comment.