diff --git a/src/py/flwr/common/configsrecord.py b/src/py/flwr/common/configsrecord.py index 494cb88586ac..332269503ac0 100644 --- a/src/py/flwr/common/configsrecord.py +++ b/src/py/flwr/common/configsrecord.py @@ -25,7 +25,6 @@ class ConfigsRecord: """Configs record.""" - keep_input: bool data: Dict[str, ConfigsRecordValues] = field(default_factory=dict) def __init__( @@ -47,12 +46,13 @@ def __init__( to True, the data is duplicated in memory. If memory is a concern, set it to False. """ - self.keep_input = keep_input self.data = {} if configs_dict: - self.set_configs(configs_dict) + self.set_configs(configs_dict, keep_input=keep_input) - def set_configs(self, configs_dict: Dict[str, ConfigsRecordValues]) -> None: + def set_configs( + self, configs_dict: Dict[str, ConfigsRecordValues], keep_input: bool = True + ) -> None: """Add configs to the record. Parameters @@ -61,6 +61,11 @@ def set_configs(self, configs_dict: Dict[str, ConfigsRecordValues]) -> None: A dictionary that stores basic types (i.e. `str`,`int`, `float`, `bytes` as defined in `ConfigsRecordValues`) and list of such types (see `ConfigsScalarList`). + keep_input : bool (default: True) + A boolean indicating whether config passed should be deleted from the input + dictionary immediately after adding them to the record. When set + to True, the data is duplicated in memory. If memory is a concern, set + it to False. """ if any(not isinstance(k, str) for k in configs_dict.keys()): raise TypeError(f"Not all keys are of valid type. Expected {str}") @@ -88,7 +93,7 @@ def is_valid(value: ConfigsScalar) -> None: is_valid(value) # Add configs to record - if self.keep_input: + if keep_input: # Copy self.data = configs_dict.copy() else: @@ -96,3 +101,7 @@ def is_valid(value: ConfigsScalar) -> None: for key in list(configs_dict.keys()): self.data[key] = configs_dict[key] del configs_dict[key] + + def __getitem__(self, key: str) -> ConfigsRecordValues: + """Retrieve an element stored in record.""" + return self.data[key] diff --git a/src/py/flwr/common/metricsrecord.py b/src/py/flwr/common/metricsrecord.py index 68eca732efa2..d66a4454635a 100644 --- a/src/py/flwr/common/metricsrecord.py +++ b/src/py/flwr/common/metricsrecord.py @@ -25,7 +25,6 @@ class MetricsRecord: """Metrics record.""" - keep_input: bool data: Dict[str, MetricsRecordValues] = field(default_factory=dict) def __init__( @@ -46,12 +45,13 @@ def __init__( to True, the data is duplicated in memory. If memory is a concern, set it to False. """ - self.keep_input = keep_input self.data = {} if metrics_dict: - self.set_metrics(metrics_dict) + self.set_metrics(metrics_dict, keep_input=keep_input) - def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None: + def set_metrics( + self, metrics_dict: Dict[str, MetricsRecordValues], keep_input: bool = True + ) -> None: """Add metrics to the record. Parameters @@ -59,6 +59,11 @@ def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None: metrics_dict : Dict[str, MetricsRecordValues] A dictionary that stores basic types (i.e. `int`, `float` as defined in `MetricsScalar`) and list of such types (see `MetricsScalarList`). + keep_input : bool (default: True) + A boolean indicating whether metrics should be deleted from the input + dictionary immediately after adding them to the record. When set + to True, the data is duplicated in memory. If memory is a concern, set + it to False. """ if any(not isinstance(k, str) for k in metrics_dict.keys()): raise TypeError(f"Not all keys are of valid type. Expected {str}.") @@ -86,7 +91,7 @@ def is_valid(value: MetricsScalar) -> None: is_valid(value) # Add metrics to record - if self.keep_input: + if keep_input: # Copy self.data = metrics_dict.copy() else: @@ -94,3 +99,7 @@ def is_valid(value: MetricsScalar) -> None: for key in list(metrics_dict.keys()): self.data[key] = metrics_dict[key] del metrics_dict[key] + + def __getitem__(self, key: str) -> MetricsRecordValues: + """Retrieve an element stored in record.""" + return self.data[key] diff --git a/src/py/flwr/common/parametersrecord.py b/src/py/flwr/common/parametersrecord.py index 3d40c0488baa..ef02a0789ddf 100644 --- a/src/py/flwr/common/parametersrecord.py +++ b/src/py/flwr/common/parametersrecord.py @@ -59,7 +59,6 @@ class ParametersRecord: PyTorch's state_dict, but holding serialised tensors instead. """ - keep_input: bool data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array]) def __init__( @@ -82,25 +81,29 @@ def __init__( parameters after adding it to the record, set this flag to True. When set to True, the data is duplicated in memory. """ - self.keep_input = keep_input self.data = OrderedDict() if array_dict: - self.set_parameters(array_dict) + self.set_parameters(array_dict, keep_input=keep_input) - def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None: + def set_parameters( + self, array_dict: OrderedDict[str, Array], keep_input: bool = False + ) -> None: """Add parameters to record. Parameters ---------- array_dict : OrderedDict[str, Array] A dictionary that stores serialized array-like or tensor-like objects. + keep_input : bool (default: False) + A boolean indicating whether parameters should be deleted from the input + dictionary immediately after adding them to the record. """ if any(not isinstance(k, str) for k in array_dict.keys()): raise TypeError(f"Not all keys are of valid type. Expected {str}") if any(not isinstance(v, Array) for v in array_dict.values()): raise TypeError(f"Not all values are of valid type. Expected {Array}") - if self.keep_input: + if keep_input: # Copy self.data = OrderedDict(array_dict) else: @@ -108,3 +111,7 @@ def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None: for key in list(array_dict.keys()): self.data[key] = array_dict[key] del array_dict[key] + + def __getitem__(self, key: str) -> Array: + """Retrieve an element stored in record.""" + return self.data[key] diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index 3f0917d75cf5..b2f53ce43303 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -111,7 +111,7 @@ def test_set_parameters_while_keeping_intputs() -> None: array_dict = OrderedDict( {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())} ) - p_record.set_parameters(array_dict) + p_record.set_parameters(array_dict, keep_input=True) # Creating a second parametersrecord passing the same `array_dict` (not erased) p_record_2 = ParametersRecord(array_dict) @@ -253,7 +253,7 @@ def test_set_metrics_to_metricsrecord_with_and_without_keeping_input( my_metrics_copy = my_metrics.copy() # Add metric - m_record.set_metrics(my_metrics) + m_record.set_metrics(my_metrics, keep_input=keep_input) # Check metrics are actually added # Check that input dict has been emptied when enabled such behaviour