From 255925938bd7c559af3f7f1ad7b363d34aab28d3 Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 22 Jan 2024 17:02:31 +0000 Subject: [PATCH] Improve records type checking (#2838) --- src/py/flwr/common/configsrecord.py | 13 +++++++++++-- src/py/flwr/common/metricsrecord.py | 13 +++++++++++-- src/py/flwr/common/recordset_test.py | 10 ++++++++++ 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/common/configsrecord.py b/src/py/flwr/common/configsrecord.py index 332269503ac0..b0480841e06c 100644 --- a/src/py/flwr/common/configsrecord.py +++ b/src/py/flwr/common/configsrecord.py @@ -87,8 +87,17 @@ def is_valid(value: ConfigsScalar) -> None: # 1s to check 10M element list on a M2 Pro # In such settings, you'd be better of treating such config as # an array and pass it to a ParametersRecord. - for list_value in value: - is_valid(list_value) + # Empty lists are valid + if len(value) > 0: + is_valid(value[0]) + # all elements in the list must be of the same valid type + # this is needed for protobuf + value_type = type(value[0]) + if not all(isinstance(v, value_type) for v in value): + raise TypeError( + "All values in a list must be of the same valid type. " + f"One of {ConfigsScalar}." + ) else: is_valid(value) diff --git a/src/py/flwr/common/metricsrecord.py b/src/py/flwr/common/metricsrecord.py index ecb8eff830ab..e70b0cb31d55 100644 --- a/src/py/flwr/common/metricsrecord.py +++ b/src/py/flwr/common/metricsrecord.py @@ -87,8 +87,17 @@ def is_valid(value: MetricsScalar) -> None: # 1s to check 10M element list on a M2 Pro # In such settings, you'd be better of treating such metric as # an array and pass it to a ParametersRecord. - for list_value in value: - is_valid(list_value) + # Empty lists are valid + if len(value) > 0: + is_valid(value[0]) + # all elements in the list must be of the same valid type + # this is needed for protobuf + value_type = type(value[0]) + if not all(isinstance(v, value_type) for v in value): + raise TypeError( + "All values in a list must be of the same valid type. " + f"One of {MetricsScalar}." + ) else: is_valid(value) diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index 0e4c351647da..83e1e4595f1d 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -162,6 +162,7 @@ def test_set_parameters_with_incorrect_types( (str, lambda x: float(x.flatten()[0])), # str: float (str, lambda x: x.flatten().astype("int").tolist()), # str: List[int] (str, lambda x: x.flatten().astype("float").tolist()), # str: List[float] + (str, lambda x: []), # str: empty list ], ) def test_set_metrics_to_metricsrecord_with_correct_types( @@ -203,6 +204,10 @@ def test_set_metrics_to_metricsrecord_with_correct_types( str, lambda x: [{str(v): v for v in x.flatten()}], ), # str: List[dict[str: float]] (supported: unsupported) + ( + str, + lambda x: [1, 2.0, 3.0, 4], + ), # str: List[mixing valid types] (supported: unsupported) ( int, lambda x: x.flatten().tolist(), @@ -278,6 +283,7 @@ def test_set_metrics_to_metricsrecord_with_and_without_keeping_input( (str, lambda x: x.flatten().astype("float").tolist()), # str: List[float] (str, lambda x: x.flatten().astype("bool").tolist()), # str: List[bool] (str, lambda x: [x.flatten().tobytes()]), # str: List[bytes] + (str, lambda x: []), # str: empyt list ], ) def test_set_configs_to_configsrecord_with_correct_types( @@ -310,6 +316,10 @@ def test_set_configs_to_configsrecord_with_correct_types( str, lambda x: [{str(v): v for v in x.flatten()}], ), # str: List[dict[str: float]] (supported: unsupported) + ( + str, + lambda x: [1, 2.0, 3.0, 4], + ), # str: List[mixing valid types] (supported: unsupported) ( int, lambda x: x.flatten().tolist(),