From 539b5a1057cdd0e981c171bbad1d4da3b3527320 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Thu, 25 Jan 2024 17:43:22 +0100 Subject: [PATCH] Fix ParametersRecord<>Parameters conversion (#2852) Co-authored-by: jafermarq --- src/py/flwr/common/recordset_compat.py | 23 +++++++++++++---------- src/py/flwr/common/recordset_test.py | 23 ++++++++++++++++++++--- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/src/py/flwr/common/recordset_compat.py b/src/py/flwr/common/recordset_compat.py index c45f7fcd9fb8..65a7181f219a 100644 --- a/src/py/flwr/common/recordset_compat.py +++ b/src/py/flwr/common/recordset_compat.py @@ -40,20 +40,22 @@ def parametersrecord_to_parameters( - record: ParametersRecord, keep_input: bool = False + record: ParametersRecord, keep_input: bool = True ) -> Parameters: """Convert ParameterRecord to legacy Parameters. - Warning: Because `Arrays` in `ParametersRecord` encode more information of the + Warnings + -------- + Because `Arrays` in `ParametersRecord` encode more information of the array-like or tensor-like data (e.g their datatype, shape) than `Parameters` it might not be possible to reconstruct such data structures from `Parameters` objects - alone. Additional information or metadta must be provided from elsewhere. + alone. Additional information or metadata must be provided from elsewhere. Parameters ---------- record : ParametersRecord The record to be conveted into Parameters. - keep_input : bool (default: False) + keep_input : bool (default: True) A boolean indicating whether entries in the record should be deleted from the input dictionary immediately after adding them to the record. """ @@ -74,7 +76,8 @@ def parametersrecord_to_parameters( def parameters_to_parametersrecord( - parameters: Parameters, keep_input: bool = False + parameters: Parameters, + keep_input: bool = True, ) -> ParametersRecord: """Convert legacy Parameters into a single ParametersRecord. @@ -86,7 +89,7 @@ def parameters_to_parametersrecord( ---------- parameters : Parameters Parameters object to be represented as a ParametersRecord. - keep_input : bool (default: False) + keep_input : bool (default=True) A boolean indicating whether parameters should be deleted from the input Parameters object (i.e. a list of serialized NumPy arrays) immediately after adding them to the record. @@ -96,17 +99,17 @@ def parameters_to_parametersrecord( p_record = ParametersRecord() num_arrays = len(parameters.tensors) + ordered_dict = OrderedDict() for idx in range(num_arrays): if keep_input: tensor = parameters.tensors[idx] else: tensor = parameters.tensors.pop(0) - p_record.set_parameters( - OrderedDict( - {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])} - ) + ordered_dict[str(idx)] = Array( + data=tensor, dtype="", stype=tensor_type, shape=[] ) + p_record.set_parameters(ordered_dict, keep_input=keep_input) return p_record diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index e1825eaeef14..10fe85a56ecc 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -87,15 +87,32 @@ def test_parameters_to_array_and_back() -> None: assert np.array_equal(ndarray, ndarray_) -def test_parameters_to_parametersrecord_and_back() -> None: +@pytest.mark.parametrize( + "keep_input", + [False, True], +) +def test_parameters_to_parametersrecord_and_back(keep_input: bool) -> None: """Test conversion between legacy Parameters and ParametersRecords.""" ndarrays = get_ndarrays() parameters = ndarrays_to_parameters(ndarrays) - params_record = parameters_to_parametersrecord(parameters=parameters) + params_record = parameters_to_parametersrecord( + parameters=parameters, keep_input=keep_input + ) + + if keep_input: + # Verify inputed parameters are indeed as originally passed + assert parameters == ndarrays_to_parameters(ndarrays) + else: + # Verify tensors have been erased + assert len(parameters.tensors) == 0 + + parameters_ = parametersrecord_to_parameters(params_record, keep_input=keep_input) - parameters_ = parametersrecord_to_parameters(params_record) + if not keep_input: + # Verify Arrays in record have been erased + assert len(params_record.data) == 0 ndarrays_ = parameters_to_ndarrays(parameters=parameters_)